diff --git a/.bazelrc b/.bazelrc index 73926e5a2f9..066b0db10bc 100644 --- a/.bazelrc +++ b/.bazelrc @@ -5,6 +5,7 @@ # Android options: # android: # android_arm: +# android_arm64: # android_x86: # android_x86_64: # @@ -18,8 +19,10 @@ # # Compiler options: # cuda_clang: Use clang when building CUDA code. -# c++17: Build with C++17 options -# c++1z: Build with C++17 options +# c++17: Build with C++17 options (links with libc++) +# c++1z: Build with C++17 options (links with libc++) +# c++17_gcc: Build with C++17 options (links with stdlibc++) +# c++1z_gcc: Build with C++17 options (links with stdlibc++) # avx_linux: Build with avx instruction set on linux. # avx2_linux: Build with avx2 instruction set on linux. # native_arch_linux: Build with instruction sets available to the host machine on linux @@ -44,10 +47,6 @@ # using_cuda: CUDA is available to build system. # cuda: Build with full cuda support. # rocm: Build with AMD GPU support (rocm). -# sycl: Build with SYCL support. -# sycl_nodouble: -# sycl_asan: -# sycl_trisycl: # mkl: Enable full mkl support. # tensorrt: Enable Tensorrt support. # ngraph: Enable ngraph support. @@ -87,6 +86,7 @@ # release_cpu_linux: Toolchain and CUDA options for Linux CPU builds. # release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds. # release_gpu_linux: Toolchain and CUDA options for Linux GPU builds. +# release_gpu_linux_cuda_10_1: Toolchain and CUDA options for CUDA 10.1 Linux GPU builds. # release_cpu_windows: Toolchain and CUDA options for Windows CPU builds. # release_gpu_windows: Toolchain and CUDA options for Windows GPU builds. @@ -159,13 +159,11 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain # environment variable "TF_MKL_ROOT" every time before build. build:mkl --define=build_with_mkl=true --define=enable_mkl=true build:mkl --define=tensorflow_mkldnn_contraction_kernel=0 -build:mkl --define=build_with_mkl_dnn_v1_only=true build:mkl -c opt # config to build OneDNN backend with a user specified threadpool. 
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0 -build:mkl_threadpool --define=build_with_mkl_dnn_v1_only=true build:mkl_threadpool --define=build_with_mkl_opensource=true build:mkl_threadpool --define=build_with_mkldnn_threadpool=true build:mkl_threadpool -c opt @@ -173,7 +171,6 @@ build:mkl_threadpool -c opt # Config setting to build with oneDNN and without the binary blob build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0 -build:mkl_opensource_only --define=build_with_mkl_dnn_v1_only=true build:mkl_opensource_only --define=build_with_mkl_opensource=true build:mkl_opensource_only -c opt @@ -214,19 +211,6 @@ build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true build:rocm --action_env TF_NEED_ROCM=1 -build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain -build:sycl --define=using_sycl=true -build:sycl --action_env TF_NEED_OPENCL_SYCL=1 - -build:sycl_nodouble --config=sycl -build:sycl_nodouble --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE - -build:sycl_nodouble --config=sycl -build:sycl_asan --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address - -build:sycl_nodouble --config=sycl -build:sycl_trisycl --define=using_trisycl=true - # Options extracted from configure script build:ngraph --define=with_ngraph_support=true build:numa --define=with_numa_support=true @@ -278,6 +262,8 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS build:c++17 --cxxopt=-std=c++1z build:c++17 --cxxopt=-stdlib=libc++ build:c++1z --config=c++17 +build:c++17_gcc --cxxopt=-std=c++1z +build:c++1z_gcc --config=c++17_gcc # Enable using platform specific build settings, except when cross-compiling for # mobile platforms. @@ -289,6 +275,7 @@ build:ios --noenable_platform_specific_config build:android --copt=-w build:ios --copt=-w build:linux --copt=-w +build:linux --host_copt=-w build:macos --copt=-w build:windows --copt=/w @@ -330,6 +317,11 @@ build:windows --host_copt=-DWIN32_LEAN_AND_MEAN build:windows --copt=-DNOGDI build:windows --host_copt=-DNOGDI +# MSVC (Windows): Standards-conformant preprocessor mode +# See https://docs.microsoft.com/en-us/cpp/preprocessor/preprocessor-experimental-overview +build:windows --copt=/experimental:preprocessor +build:windows --host_copt=/experimental:preprocessor + # Misc build options we need for windows. build:windows --linkopt=/DEBUG build:windows --host_linkopt=/DEBUG @@ -354,6 +346,7 @@ build --config=short_logs # TODO(gunan): Create a feature in toolchains for avx/avx2 to # avoid having to define linux/win separately. build:avx_linux --copt=-mavx +build:avx_linux --host_copt=-mavx build:avx2_linux --copt=-mavx2 build:native_arch_linux --copt=-march=native build:avx_win --copt=/arch=AVX @@ -368,7 +361,6 @@ build --config=v2 test --config=v2 # Enable XLA -build:xla --action_env=TF_ENABLE_XLA=1 build:xla --define=with_xla_support=true # BEGIN TF REMOTE BUILD EXECUTION OPTIONS @@ -408,9 +400,12 @@ build:rbe_linux --config=avx_linux build:rbe_linux --config=short_logs # TODO(gunan): Check why we need this specified in rbe, but not in other builds. 
build:rbe_linux --linkopt=-lrt +build:rbe_linux --host_linkopt=-lrt build:rbe_linux --linkopt=-lm +build:rbe_linux --host_linkopt=-lm build:rbe_cpu_linux --config=rbe_linux +build:rbe_cpu_linux --host_crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain" build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain" build:rbe_cpu_linux --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8" build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" @@ -428,6 +423,7 @@ test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/ build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true +build:rbe_linux_cuda10.1_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64" build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" @@ -444,6 +440,7 @@ build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true +build:rbe_linux_cuda11.0_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain" build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain" build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64" build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform" @@ -458,12 +455,12 @@ build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7" build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8" -# Map default to CUDA 10.1. +# Map default to CUDA 11 for PY35 and greater. 
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7 -build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda10.1_nvcc_py3.5 -build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda10.1_nvcc_py3.6 -build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda10.1_nvcc_py3.7 -build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda10.1_nvcc_py3.8 +build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda11.0_nvcc_py3.5 +build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda11.0_nvcc_py3.6 +build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda11.0_nvcc_py3.7 +build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda11.0_nvcc_py3.8 # Deprecated configs that people might still use. build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36 @@ -580,11 +577,11 @@ build:release_cpu_macos --config=avx_linux build:release_gpu_common --config=release_common build:release_gpu_common --config=cuda build:release_gpu_common --config=tensorrt -build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1" -build:release_gpu_common --action_env=TF_CUDA_VERSION="10" -build:release_gpu_common --action_env=TF_CUDNN_VERSION="7" +build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.0" +build:release_gpu_common --action_env=TF_CUDA_VERSION="11" +build:release_gpu_common --action_env=TF_CUDNN_VERSION="8" build:release_gpu_common --action_env=TF_NEED_TENSORRT="1" -build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_37,sm_52,sm_60,sm_61,compute_70" +build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80" build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt" build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5" @@ -592,8 +589,7 @@ build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5" build:release_gpu_linux --config=release_gpu_common build:release_gpu_linux --config=avx_linux -build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain - +build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain build:release_windows_common --config=release_common build:release_windows_common --define=no_tensorflow_py_deps=true build:release_windows_common --announce_rc @@ -601,3 +597,8 @@ build:release_windows_common --announce_rc build:release_cpu_windows --config=release_windows_common build:release_gpu_windows --config=release_windows_common + +build:release_gpu_linux_cuda_10_1 --config=release_gpu_linux +build:release_gpu_linux_cuda_10_1 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1" +build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDA_VERSION="10" +build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDNN_VERSION="7" diff --git a/.github/bot_config.yml b/.github/bot_config.yml index d0e7256aec0..952ff91fef7 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -40,6 +40,22 @@ segfault_memory: # assignees filesystem_security_assignee: - mihaimaruseac + +tflite_micro_path: + - tensorflow/lite/micro + +tflite_micro_comment: > + Thanks for contributing to TensorFlow Lite Micro. 
+ + + To keep this process moving along, we'd like to make sure that you have completed the items on this list: + * Read the [contributing guidelines for TensorFlow Lite Micro](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/CONTRIBUTING.md) + * Created a [TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md) + * Linked to the issue from the PR description + + + We would like to have a discussion on the Github issue first to determine the best path forward, and then proceed to the PR review. + # Cuda Comment cuda_comment: > From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries: diff --git a/ADOPTERS.md b/ADOPTERS.md deleted file mode 100644 index c0be567dc14..00000000000 --- a/ADOPTERS.md +++ /dev/null @@ -1,10 +0,0 @@ -# TensorFlow Adopters - -This page contains a list of people and organizations who are using TensorFlow. If you'd like to be included -here, please send a pull request which modifies this file. - -We intend to use this list to contact you for surveys, and to find good candidates for invite-only events. -We will also point to this list if we are asked who uses TensorFlow. - -We will not use any of the information here for promotions or to send other regular communications. You -should subscribe to discuss@tensorflow.org for such announcements. diff --git a/CODEOWNERS b/CODEOWNERS index 3ef02ffd68c..83ad24b2845 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,16 +1,15 @@ # Where component owners are known, add them here. -/tensorflow/c/eager @jaingaurav @alextp -/tensorflow/core/common_runtime/eager @jaingaurav @alextp +/tensorflow/c/eager @qqfish @kkimdev +/tensorflow/core/common_runtime/eager @qqfish @kkimdev /tenosrflow/core/debug @caisq /tensorflow/core/nccl/ @azaks2 @chsigg -/tensorflow/core/platform/windows/ @mrry +/tensorflow/core/platform/windows/ @gunan @mihaimaruseac /tensorflow/lite/experimental/micro @petewarden @advaitjain /tensorflow/python/autograph/ @mdanatg @kkimdev /tensorflow/python/debug @caisq -/tensorflow/python/eager @jaingaurav @alextp +/tensorflow/python/eager @rohan100jain @kkimdev /tensorflow/python/tools/api/generator/ @annarev -/tensorflow/tensorboard/ @jart /tensorflow/tools/docs/ @markdaoust /third_party/systemlibs/ @perfinion diff --git a/README.md b/README.md index 6398e8e27a1..f888f6bd9d4 100644 --- a/README.md +++ b/README.md @@ -157,7 +157,7 @@ Build Type * [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml) * [TensorFlow Twitter](https://twitter.com/tensorflow) * [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) -* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) +* [TensorFlow Roadmap](https://www.tensorflow.org/model_optimization/guide/roadmap) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib) * [TensorBoard Visualization Toolkit](https://github.com/tensorflow/tensorboard) diff --git a/RELEASE.md b/RELEASE.md index b0c785c7d68..89dd3a8a78c 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -22,6 +22,7 @@ * Code that uses `tf.map_fn`/`tf.cond`/`tf.while_loop`/control flow as op layers and happens to work before TF 2.4. These will explicitly be unsupported now. Converting these ops to Functional API op layers was unreliable before TF 2.4, and prone to erroring incomprehensibly or being silently buggy. 
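The breaking change above about `tf.map_fn`/`tf.cond`/`tf.while_loop` used as op layers can be illustrated by the pattern that remains supported: keep raw TF control flow inside a `tf.keras.layers.Layer` subclass (or a `Lambda` layer) rather than applying it directly to symbolic Keras inputs. A minimal sketch, not part of this patch; the layer name, threshold, and shapes are illustrative only:

```python
import tensorflow as tf

class ClipOrScale(tf.keras.layers.Layer):
    """Wraps tf.cond so control flow runs inside Layer.call, not on Keras symbolic inputs."""

    def call(self, x):
        # The cond executes on concrete tensors when the model runs,
        # instead of being converted into a Functional API op layer.
        return tf.cond(tf.reduce_mean(x) > 0.0,
                       lambda: tf.clip_by_value(x, -1.0, 1.0),
                       lambda: x * 0.5)

inputs = tf.keras.Input(shape=(8,))
outputs = ClipOrScale()(inputs)   # supported: control flow is hidden inside the layer
model = tf.keras.Model(inputs, outputs)
```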
* Code that directly asserts on a Keras symbolic value in cases where ops like `tf.rank` used to return a static or symbolic value depending on if the input had a fully static shape or not. Now these ops always return symbolic values. * Code already susceptible to leaking tensors outside of graphs becomes slightly more likely to do so now. + * Code that tries directly getting gradients with respect to symbolic Keras inputs/outputs. Use GradientTape on the actual Tensors passed to the already-constructed model instead. * Code that requires very tricky shape manipulation via converted op layers in order to work, where the Keras symbolic shape inference proves insufficient. * Code that tries manually walking a `tf.keras.Model` layer by layer and assumes layers only ever have one positional argument. This assumption doesn't hold true before TF 2.4 either, but is more likely to cause issues know. * Code that manually enters `keras.backend.get_graph()` before building a functional model. This is no longer needed. @@ -33,6 +34,18 @@ shape assumptions (note that you can pass shapes with `None` entries for axes that are meant to be dynamic). You can also disable the input checking entirely by setting `model.input_spec = None`. +* XLA:CPU and XLA:GPU devices are no longer registered by default. Use + `TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be + removed). +* `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type + `tf.complex64` or `tf.complex128`, because the behavior of these ops is not + well defined for complex types. +* `tf.data.experimental.service.DispatchServer` now takes a config tuple + instead of individual arguments. Usages should be updated to + `tf.data.experimental.service.DispatchServer(dispatcher_config)`. +* `tf.data.experimental.service.WorkerServer` now takes a config tuple + instead of individual arguments. Usages should be updated to + `tf.data.experimental.service.WorkerServer(worker_config)`. ## Known Caveats @@ -67,11 +80,24 @@ the same sparsity pattern, but with new provided values. It is similar to the `with_values` function of `RaggedTensor`. * Added `StatelessCase` op, and uses it if none of case branches has stateful ops. + * Added `tf.config.experimental.get_memory_usage` to return total memory usage + of the device. * `tf.data`: * Added new `tf.data.experimental.service.register_dataset` and `tf.data.experimental.service.from_dataset_id` APIs to enable one process to register a dataset with the tf.data service, and another process to consume data from the dataset. + * Added support for tf.data service dispatcher fault tolerance. To enable + fault tolerance, configure a `work_dir` when running your dispatcher + server and set `dispatcher_fault_tolerance=True`. The dispatcher will + store its state to `work_dir`, so that on restart it can continue from its + previous state after restart. + * Added tf.data service support for sharing dataset graphs via shared + filesystem instead of over RPC. This reduces load on the dispatcher, + improving performance of distributing datasets. For this to work, the + dispatcher's `work_dir` must be accessible from workers. If the worker + fails to read from the `work_dir`, it falls back to using RPC for dataset + graph transfer. * Added optional `exclude_cols` parameter to CsvDataset. This parameter is the complement of `select_cols`; at most one of these should be specified. 
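A minimal sketch of the config-based `DispatchServer`/`WorkerServer` construction and the dispatcher fault tolerance described in the `tf.data` notes above. This is not part of the patch: the port, `work_dir` path, and `processing_mode` are placeholders, and the exact config field spelling (the note says `dispatcher_fault_tolerance`; the released API may spell it `fault_tolerant_mode`) should be checked against the final API:

```python
import tensorflow as tf

svc = tf.data.experimental.service

# The dispatcher and worker servers now take a single config object
# instead of individual constructor arguments.
dispatcher = svc.DispatchServer(
    svc.DispatcherConfig(port=5050,
                         work_dir="/tmp/tfdata_dispatcher",  # dispatcher state persisted here
                         fault_tolerant_mode=True))          # resume from work_dir on restart
worker = svc.WorkerServer(
    svc.WorkerConfig(dispatcher_address="localhost:5050"))

# One process registers a dataset with the service; another consumes it by id.
dataset = tf.data.Dataset.range(10)
dataset_id = svc.register_dataset(dispatcher.target, dataset)
consumed = svc.from_dataset_id(processing_mode="parallel_epochs",
                               service=dispatcher.target,
                               dataset_id=dataset_id,
                               element_spec=dataset.element_spec)
```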
* We have implemented an optimization which reorders data-discarding @@ -79,11 +105,19 @@ dataset when it is safe to do so. The optimization can be disabled via the `experimental_optimization.reorder_data_discarding_ops` dataset option. + * `tf.data.Options` were previously immutable and can now be overridden. + * `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors + with a new `output_signature` argument, which allows `from_generator` to + produce any type describable by a `tf.TypeSpec`. + * `tf.data.experimental.AUTOTUNE` is now available in the core API as + `tf.data.AUTOTUNE`. * `tf.image`: * Added deterministic `tf.image.stateless_random_*` functions for each - `tf.image.random_*` function. Given the same seed, the stateless functions - produce the same results independent of how many times the function is - called, and independent of global seed settings. + `tf.image.random_*` function. Added a new op + `stateless_sample_distorted_bounding_box` which is a deterministic + version of the `sample_distorted_bounding_box` op. Given the same seed, these + stateless functions/ops produce the same results independent of how many + times the function is called, and independent of global seed settings. * `tf.distribute`: * * `tf.keras`: @@ -95,16 +129,49 @@ * Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand. * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape` as an alternative to accepting a `callable` loss. + * Added `beta` hyperparameter to FTRL optimizer classes (Keras and others) + to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf). + * Added `mobilenet_v3` to Keras applications. + * `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for + customization of how gradients are aggregated across devices, as well as + `gradient_transformers` to allow for custom gradient transformations + (such as gradient clipping). + * The `steps_per_execution` argument in `compile()` is no longer + experimental; if you were passing `experimental_steps_per_execution`, + rename it to `steps_per_execution` in your code. This argument controls + the number of batches to run during each `tf.function` call when calling + `fit()`. Running multiple batches inside a single `tf.function` call can + greatly improve performance on TPUs or small models with a large Python + overhead. * `tf.function` / AutoGraph: * Added `experimental_follow_type_hints` argument for `tf.function`. When True, the function may use type annotations to optimize the tracing performance. * Added support for `iter(DistributedDataset)` in AutoGraph `for` loops. + * AutoGraph now allows creating new symbols inside a TensorFlow loop, if + the values of these symbols at an iteration do not depend on the previous + iteration. These types of loops must run at least one iteration, and will + raise a runtime error otherwise. + + Example: + + ``` + for batch in data: + outputs = train_step(batch) + tf.print('final outputs', outputs) + ``` + See tensorflow/python/autograph/g3doc/reference/limitations.md for more + info. * `tf.lite`: * `DynamicBuffer::AddJoinedString()` will now add a separator if the first string to be joined is empty. * `TFLiteConverter`: * Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models.
This allows users to modify the model input and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting to float type (`tf.float32`). + * Deprecate `Interpreter::UseNNAPI(bool)` C++ API + * Prefer using `NnApiDelegate()` and related delegate configuration methods directly. + * Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair. + * TFLite Profiler for Android is available. See the detailed + [guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android). * * `tf.random`: * @@ -116,14 +183,28 @@ behavior by adjusting the `l2` parameter. * * XLA Support: + * xla.experimental.compile is deprecated, use + `tf.function(experimental_compile=True)` instead + * Added `tf.function.experimental_get_compiler_ir` which returns compiler IR + (currently 'hlo' and 'optimized_hlo') for given input for given function. * * Tracing and Debugging: * +* `tf.train.Checkpoint`: + * Now accepts a `root` argument in the initialization, which generates a + checkpoint with a root object. This allows users to create a `Checkpoint` + object that is compatible with Keras `model.save_weights()` and + `model.load_weights`. The checkpoint is also compatible with the + checkpoint saved in the `variables/` folder in the SavedModel. + * When restoring, `save_path` can be a path to a SavedModel. The function + will automatically find the checkpoint in the SavedModel. +* `tf.nn`: + * `tf.nn.max_pool2d` now supports explicit padding. * Other: * We have replaced uses of "whitelist" and "blacklist" with "allowlist" and "denylist" where possible. Please see https://developers.google.com/style/word-list#blacklist for more context. - * + ## Thanks to our Contributors @@ -215,6 +296,7 @@ stjohnso98, , , , , * Mutable tables now restore checkpointed values when loaded from SavedModel. * GPU * TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities. + * Remove environmental variable `TF_USE_CUDNN`. * Others * Retain parent namescope for ops added inside `tf.while_loop`/`tf.cond`/`tf.switch_case`. * Update `tf.vectorized_map` to support vectorizing `tf.while_loop` and TensorList operations. @@ -1546,6 +1628,7 @@ Yuan (Terry) Tang, Yuchen Ying, Yves-Noel Weweler, zhangyujing, zjjott, zyeric, color palette of the frame. This has been fixed now * image.resize now considers proper pixel centers and has new kernels (incl. anti-aliasing). + * Added an isotonic regression solver (tf.nn.isotonic_regression). * Performance * Turn on MKL-DNN contraction kernels by default. 
MKL-DNN dynamically dispatches the best kernel implementation based on CPU vector diff --git a/configure.cmd b/configure.cmd index 021afdbbea1..738e106da18 100644 --- a/configure.cmd +++ b/configure.cmd @@ -16,5 +16,5 @@ set configure_dir=%~dp0 set configure_dir=%configure_dir:~0,-1% -python %configure_dir%\configure.py %* || ( exit /b ) +python "%configure_dir%\configure.py" %* || ( exit /b ) echo Configuration finished diff --git a/configure.py b/configure.py index 9524eada3cd..5b9fd55b740 100644 --- a/configure.py +++ b/configure.py @@ -38,9 +38,6 @@ _DEFAULT_CUDNN_VERSION = '7' _DEFAULT_TENSORRT_VERSION = '6' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' -_TF_OPENCL_VERSION = '1.2' -_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' -_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' _SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16, 17, 18] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 @@ -1114,62 +1111,6 @@ def set_host_c_compiler(environ_cp): write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler) -def set_computecpp_toolkit_path(environ_cp): - """Set COMPUTECPP_TOOLKIT_PATH.""" - - def toolkit_exists(toolkit_path): - """Check if a computecpp toolkit path is valid.""" - if is_linux(): - sycl_rt_lib_path = 'lib/libComputeCpp.so' - else: - sycl_rt_lib_path = '' - - sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path) - exists = os.path.exists(sycl_rt_lib_path_full) - if not exists: - print('Invalid SYCL %s library path. %s cannot be found' % - (_TF_OPENCL_VERSION, sycl_rt_lib_path_full)) - return exists - - computecpp_toolkit_path = prompt_loop_or_load_from_env( - environ_cp, - var_name='COMPUTECPP_TOOLKIT_PATH', - var_default=_DEFAULT_COMPUTECPP_TOOLKIT_PATH, - ask_for_var=( - 'Please specify the location where ComputeCpp for SYCL %s is ' - 'installed.' % _TF_OPENCL_VERSION), - check_success=toolkit_exists, - error_msg='Invalid SYCL compiler path. %s cannot be found.', - suppress_default_error=True) - - write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH', - computecpp_toolkit_path) - - -def set_trisycl_include_dir(environ_cp): - """Set TRISYCL_INCLUDE_DIR.""" - - ask_trisycl_include_dir = ('Please specify the location of the triSYCL ' - 'include directory. 
(Use --config=sycl_trisycl ' - 'when building with Bazel) ' - '[Default is %s]: ') % ( - _DEFAULT_TRISYCL_INCLUDE_DIR) - - while True: - trisycl_include_dir = get_from_env_or_user_or_default( - environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir, - _DEFAULT_TRISYCL_INCLUDE_DIR) - if os.path.exists(trisycl_include_dir): - break - - print('Invalid triSYCL include directory, %s cannot be found' % - (trisycl_include_dir)) - - # Set TRISYCL_INCLUDE_DIR - environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir - write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir) - - def system_specific_test_config(environ_cp): """Add default build and test flags required for TF tests to bazelrc.""" write_to_bazelrc('test --flaky_test_attempts=3') @@ -1397,8 +1338,6 @@ def main(): setup_python(environ_cp) if is_windows(): - environ_cp['TF_NEED_OPENCL_SYCL'] = '0' - environ_cp['TF_NEED_COMPUTECPP'] = '0' environ_cp['TF_NEED_OPENCL'] = '0' environ_cp['TF_CUDA_CLANG'] = '0' environ_cp['TF_NEED_TENSORRT'] = '0' @@ -1415,21 +1354,6 @@ def main(): if environ_cp.get('TF_ENABLE_XLA', '1') == '1': write_to_bazelrc('build --config=xla') - set_action_env_var( - environ_cp, - 'TF_NEED_OPENCL_SYCL', - 'OpenCL SYCL', - False, - bazel_config_name='sycl') - if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': - set_host_cxx_compiler(environ_cp) - set_host_c_compiler(environ_cp) - set_action_env_var(environ_cp, 'TF_NEED_COMPUTECPP', 'ComputeCPP', True) - if environ_cp.get('TF_NEED_COMPUTECPP') == '1': - set_computecpp_toolkit_path(environ_cp) - else: - set_trisycl_include_dir(environ_cp) - set_action_env_var( environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm') if (environ_cp.get('TF_NEED_ROCM') == '1' and @@ -1442,6 +1366,11 @@ def main(): write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH')) write_action_env_to_bazelrc('ROCM_ROOT', environ_cp.get('ROCM_PATH')) + if ((environ_cp.get('TF_NEED_ROCM') == '1') and + (environ_cp.get('TF_ENABLE_MLIR_GENERATED_GPU_KERNELS') == '1')): + write_to_bazelrc( + 'build:rocm --define tensorflow_enable_mlir_generated_gpu_kernels=1') + environ_cp['TF_NEED_CUDA'] = str( int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False))) if (environ_cp.get('TF_NEED_CUDA') == '1' and @@ -1523,17 +1452,15 @@ def main(): # use it for the CPU build. set_tf_download_clang(environ_cp) - # SYCL / ROCm / CUDA are mutually exclusive. + # ROCm / CUDA are mutually exclusive. # At most 1 GPU platform can be configured. gpu_platform_count = 0 - if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': - gpu_platform_count += 1 if environ_cp.get('TF_NEED_ROCM') == '1': gpu_platform_count += 1 if environ_cp.get('TF_NEED_CUDA') == '1': gpu_platform_count += 1 if gpu_platform_count >= 2: - raise UserInputError('SYCL / CUDA / ROCm are mututally exclusive. ' + raise UserInputError('CUDA / ROCm are mututally exclusive. ' 'At most 1 GPU platform can be configured.') set_cc_opt_flags(environ_cp) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index d1c1d7dcdef..668f3a55579 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -562,6 +562,7 @@ selects.config_setting_group( package_group( name = "internal", packages = [ + "//learning/brain/distribute/...", "//learning/brain/swift/x10/...", "//perftools/accelerators/xprof/api/...", "//tensorflow/...", @@ -578,11 +579,6 @@ package_group( packages = ["//learning/pathways/..."], ) -# Packages that use composite tensors or dispatch. -# TODO(b/154762408) Remove this package group once it's no longer needed. 
-# If this is modified, then copy.bara.sky must also be modified. -package_group(name = "composite_tensor_whitelist") - # Packages that use private types symbols, until they are exported. # TODO(b/154650521) Remove. package_group( diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 0cd2b7da139..5932dda514d 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -137,7 +137,7 @@ if _running_from_pip_package(): # TODO(gunan): Add sanity checks to loaded modules here. for _s in _site_packages_dirs: # Load first party dynamic kernels. - _main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels') + _main_dir = _os.path.join(_s, 'tensorflow/core/kernels') if _fi.file_exists(_main_dir): _ll.load_library(_main_dir) diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index b73af197f7b..0d1d2e56fae 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -147,7 +147,7 @@ if _running_from_pip_package(): # TODO(gunan): Add sanity checks to loaded modules here. for _s in _site_packages_dirs: # Load first party dynamic kernels. - _main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels') + _main_dir = _os.path.join(_s, 'tensorflow/core/kernels') if _fi.file_exists(_main_dir): _ll.load_library(_main_dir) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index e5efe323922..01f48cad192 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -23,6 +23,7 @@ filegroup( srcs = [ "c_api.h", "c_api_experimental.h", + "c_api_macros.h", "tensor_interface.h", "tf_attrtype.h", "tf_datatype.h", @@ -57,10 +58,11 @@ filegroup( visibility = ["//visibility:public"], ) -filegroup( +cc_library( name = "pywrap_required_hdrs", - srcs = [ + textual_hdrs = [ "c_api_internal.h", + "c_api_macros.h", "conversion_macros.h", "python_api.h", "tensor_interface.h", @@ -79,6 +81,7 @@ tf_cuda_library( hdrs = [ "c_api.h", "c_api_internal.h", + "c_api_macros.h", "tf_datatype.h", "tf_tensor.h", "tf_tstring.h", @@ -217,6 +220,7 @@ cc_library( name = "logging", srcs = ["logging.cc"], hdrs = ["logging.h"], + visibility = ["//visibility:public"], deps = [ ":c_api_macros", "//tensorflow/core/platform:logging", @@ -310,6 +314,7 @@ cc_library( hdrs = ["tf_tensor.h"], visibility = ["//visibility:public"], deps = [ + ":c_api_macros", ":tensor_interface", ":tf_datatype", ":tf_status", @@ -336,6 +341,7 @@ tf_cuda_library( ], visibility = ["//tensorflow:internal"], deps = [ + ":c_api_macros", ":tensor_interface", ":tf_datatype", ":tf_status", @@ -371,6 +377,7 @@ tf_cuda_library( "//tensorflow/c/eager:tfe_op_internal", "//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/compiler/jit:flags", + "//tensorflow/compiler/jit:get_compiler_ir", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -381,6 +388,7 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/platform", + "//tensorflow/core/platform:blocking_counter", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index b4297033b6d..81fb9d1a2b8 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/net.h" @@ -560,6 +561,21 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx, collective_executor_handle->get()->StartAbort(status->status); } +TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx, + const char* task, + TF_Status* status) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + auto collective_executor_handle = context->GetCollectiveExecutorHandle(); + tensorflow::Notification done; + collective_executor_handle->get()->remote_access()->CheckPeerHealth( + task, [&done, status](const Status& s) { + status->status = s; + done.Notify(); + }); + done.WaitForNotification(); +} + TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) { TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList; result->num_items = num_items; diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index ebd14b4b571..c9c74f4e874 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -231,13 +231,20 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, TF_Status* status); // Aborts all ongoing collectives with the specified status. After abortion, -// subsequent collectives will error with this status immediately. +// subsequent collectives will error with this status immediately. To reset the +// collectives, create a new EagerContext. // -// This is intended to be used when a peer failure is detected. There's yet no -// way to reset the collectives other than restarting the program. +// This is intended to be used when a peer failure is detected. TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx, TF_Status* status); +// Checks the health of collective ops peers. Explicit health check is needed in +// multi worker collective ops to detect failures in the cluster. If a peer is +// down, collective ops may hang. +TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx, + const char* task, + TF_Status* status); + // Information about the shape of a Tensor and its type. struct TF_ShapeAndType { // Number of dimensions. -1 indicates unknown rank. diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 3fff9bcd371..ec8cfe4a31a 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1704,66 +1704,5 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) { TF_DeleteFunction(func1); } -// This test only works when the TF build includes XLA compiler. One way to set -// this up is via bazel build option "--define with_xla_support=true". -// -// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to -// something like TENSORFLOW_CAPI_USE_XLA. 
-#ifdef TENSORFLOW_EAGER_USE_XLA -TEST_F(CApiFunctionTest, StatelessIf_XLA) { - TF_Function* func; - const std::string funcName = "BranchFunc"; - DefineFunction(funcName.c_str(), &func); - TF_GraphCopyFunction(host_graph_, func, nullptr, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - TF_Operation* feed = Placeholder(host_graph_, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - TF_Operation* true_cond = ScalarConst(true, host_graph_, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - TF_OperationDescription* desc = - TF_NewOperation(host_graph_, "StatelessIf", "IfNode"); - TF_AddInput(desc, {true_cond, 0}); - TF_Output inputs[] = {{feed, 0}}; - TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs)); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_SetAttrType(desc, "Tcond", TF_BOOL); - TF_DataType inputType = TF_INT32; - TF_SetAttrTypeList(desc, "Tin", &inputType, 1); - TF_SetAttrTypeList(desc, "Tout", &inputType, 1); - TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size()); - TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size()); - TF_SetDevice(desc, "/device:XLA_CPU:0"); - auto op = TF_FinishOperation(desc, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - ASSERT_NE(op, nullptr); - - // Create a session for this graph. - CSession csession(host_graph_, s_, /*use_XLA*/ true); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - // Run the graph. - csession.SetInputs({{feed, Int32Tensor(17)}}); - csession.SetOutputs({op}); - csession.Run(s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_Tensor* out = csession.output_tensor(0); - ASSERT_TRUE(out != nullptr); - EXPECT_EQ(TF_INT32, TF_TensorType(out)); - EXPECT_EQ(0, TF_NumDims(out)); // scalar - ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); - int32* output_contents = static_cast(TF_TensorData(out)); - EXPECT_EQ(-17, *output_contents); - - // Clean up - csession.CloseAndDelete(s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - TF_DeleteFunction(func); -} -#endif // TENSORFLOW_EAGER_USE_XLA - } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_macros.h b/tensorflow/c/c_api_macros.h index 85c9507db87..e0c91a0d549 100644 --- a/tensorflow/c/c_api_macros.h +++ b/tensorflow/c/c_api_macros.h @@ -30,4 +30,17 @@ limitations under the License. #endif // _WIN32 #endif // SWIG +// TF_Bool is the C API typedef for unsigned char, while TF_BOOL is +// the datatype for boolean tensors. +#ifndef TF_Bool +#define TF_Bool unsigned char +#endif // TF_Bool + +// Macro used to calculate struct size for maintaining ABI stability across +// different struct implementations. 
+#ifndef TF_OFFSET_OF_END +#define TF_OFFSET_OF_END(TYPE, MEMBER) \ + (offsetof(TYPE, MEMBER) + sizeof(((TYPE *)0)->MEMBER)) +#endif // TF_OFFSET_OF_END + #endif // TENSORFLOW_C_C_API_MACROS_H_ diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 61701bc8b21..d259b32f339 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -6,7 +6,6 @@ load( "tf_copts", "tf_cuda_cc_test", "tf_cuda_library", - "tfe_xla_copts", ) load( "//tensorflow/core/platform:build_config.bzl", @@ -31,7 +30,7 @@ tf_cuda_library( "c_api_unified_experimental.h", ], hdrs = ["c_api.h"], - copts = tf_copts() + tfe_xla_copts(), + copts = tf_copts(), visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ @@ -72,13 +71,6 @@ tf_cuda_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler/lib:traceme", ], - }) + select({ - "//tensorflow:with_xla_support": [ - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/jit", - "//tensorflow/compiler/jit:xla_device", - ], - "//conditions:default": [], }) + [ "@com_google_absl//absl/memory", "//tensorflow/core/common_runtime/eager:eager_operation", @@ -109,11 +101,17 @@ filegroup( "c_api_experimental.h", "c_api_internal.h", "c_api_unified_experimental.h", + "c_api_unified_experimental_internal.h", "dlpack.h", + "gradients.h", + "gradients_internal.h", "immediate_execution_context.h", "immediate_execution_operation.h", "immediate_execution_tensor_handle.h", + "mnist_gradients_testutil.h", + "tape.h", "tfe_cancellation_manager_internal.h", + "tfe_context_internal.h", "tfe_executor_internal.h", "tfe_monitoring_internal.h", "tfe_op_attrs_internal.h", @@ -171,31 +169,6 @@ cc_library( ], ) -cc_library( - name = "gradients", - srcs = [ - "gradients.cc", - "gradients_internal.h", - ], - hdrs = [ - "gradients.h", - ], - visibility = [ - "//tensorflow:internal", - ], - deps = [ - ":abstract_context", - ":abstract_operation", - ":abstract_tensor_handle", - ":c_api_unified_internal", - ":tape", - "//tensorflow/core/common_runtime/eager:attr_builder", - "//tensorflow/core/lib/llvm_rtti", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - ], -) - cc_library( name = "gradients_internal", srcs = [ @@ -228,7 +201,6 @@ tf_cuda_cc_test( "gradients_test.cc", ], args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ @@ -240,6 +212,7 @@ tf_cuda_cc_test( "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", "//tensorflow/c:tf_status_helper", + "//tensorflow/c/experimental/gradients:array_grad", "//tensorflow/c/experimental/gradients:math_grad", "//tensorflow/c/experimental/ops:array_ops", "//tensorflow/cc/profiler", @@ -249,6 +222,184 @@ tf_cuda_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "gradients_util", + srcs = [ + "gradients_util.cc", + ], + hdrs = [ + "gradients_util.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":abstract_context", + ":abstract_operation", + ":abstract_tensor_handle", + ":c_api", + ":c_api_experimental", + ":c_api_unified_internal", + ":gradients_internal", + ":tape", + "//tensorflow/c:c_api", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/ops:math_ops", 
+ "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/cc/profiler", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "mnist_gradients_testutil", + srcs = [ + "mnist_gradients_testutil.cc", + ], + hdrs = [ + "mnist_gradients_testutil.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":abstract_tensor_handle", + ":c_api_experimental", + ":c_api_unified_internal", + ":gradients_internal", + ":gradients_util", + ":tape", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "gradient_checker", + srcs = [ + "gradient_checker.cc", + ], + hdrs = [ + "gradient_checker.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":abstract_tensor_handle", + ":c_api_experimental", + ":c_api_unified_internal", + ":gradients_internal", + ":gradients_util", + "//tensorflow/c:c_api", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/experimental/gradients:math_grad", + "//tensorflow/c/experimental/gradients:nn_grad", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/cc/profiler", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cuda_cc_test( + name = "gradient_checker_test", + size = "small", + srcs = [ + "gradient_checker_test.cc", + ], + args = ["--heap_check=local"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["nomac"], + deps = [ + ":abstract_tensor_handle", + ":c_api_experimental", + ":c_api_test_util", + ":c_api_unified_internal", + ":gradient_checker", + ":gradients_internal", + ":gradients_util", + ":mnist_gradients_testutil", + "//tensorflow/c:c_api", + "//tensorflow/c:c_test_util", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/experimental/gradients:math_grad", + "//tensorflow/c/experimental/gradients:nn_grad", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/cc/profiler", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cuda_cc_test( + name = "mnist_gradients_test", + size = "small", + srcs = [ + "mnist_gradients_test.cc", + ], + args = ["--heap_check=local"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + [ + "nomac", + ], + deps = [ + ":abstract_tensor_handle", + ":c_api_experimental", + ":c_api_unified_internal", + ":gradients_internal", + ":gradients_util", + 
":mnist_gradients_testutil", + "//tensorflow/c:c_api", + "//tensorflow/c:c_test_util", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/experimental/gradients:math_grad", + "//tensorflow/c/experimental/gradients:nn_grad", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/cc/profiler", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -482,7 +633,6 @@ tf_cuda_cc_test( "c_api_debug_test.cc", "c_api_test.cc", ], - extra_copts = tfe_xla_copts(), tags = [ "noguitar", # TODO(b/155445984): flaky #"guitar", @@ -508,6 +658,27 @@ tf_cuda_cc_test( ], ) +tf_cuda_library( + name = "c_api_remote_test_util", + testonly = 1, + srcs = ["c_api_remote_test_util.cc"], + hdrs = ["c_api_remote_test_util.h"], + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":c_api", + ":c_api_internal", + ":c_api_test_util", + ":tfe_tensorhandle_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "@com_google_absl//absl/strings", + ], +) + tf_cuda_cc_test( name = "c_api_remote_test", size = "small", @@ -516,7 +687,6 @@ tf_cuda_cc_test( ], # TODO(b/136478427): Figure out how to correctly shut the server down args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), tags = [ "no_windows", ], @@ -524,6 +694,7 @@ tf_cuda_cc_test( ":c_api", ":c_api_experimental", ":c_api_internal", + ":c_api_remote_test_util", ":c_api_test_util", ":tfe_tensorhandle_internal", "//tensorflow/c:c_test_util", @@ -540,6 +711,24 @@ tf_cuda_cc_test( ], ) +tf_cuda_cc_test( + name = "c_api_remote_function_test", + size = "small", + srcs = [ + "c_api_remote_function_test.cc", + ], + # TODO(b/136478427): Figure out how to correctly shut the server down + args = ["--heap_check=local"], + tags = [ + "no_windows", + ], + deps = [ + ":c_api_remote_test_util", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cuda_cc_test( name = "c_api_distributed_test", size = "small", @@ -548,7 +737,6 @@ tf_cuda_cc_test( ], # TODO(b/136478427): Figure out how to correctly shut the server down args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), tags = [ "no_windows", "noasan", # leaks gRPC server instances @@ -582,7 +770,6 @@ tf_cuda_cc_test( ], # TODO(b/136478427): Figure out how to correctly shut the server down args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), tags = [ "no_windows", ], @@ -617,7 +804,7 @@ tf_cuda_library( "c_api_experimental.h", "c_api_unified_experimental.h", ], - copts = tf_copts() + tfe_xla_copts(), + copts = tf_copts(), visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ @@ -689,7 +876,6 @@ tf_cuda_cc_test( "c_api_experimental_test.cc", ], args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ @@ -702,6 +888,7 @@ tf_cuda_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + 
"//tensorflow/core/platform:status", "@com_google_absl//absl/strings", ], ) @@ -713,7 +900,6 @@ tf_cuda_cc_test( "c_api_unified_experimental_test.cc", ], args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ @@ -722,6 +908,7 @@ tf_cuda_cc_test( ":c_api_test_util", "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", + "//tensorflow/c:tf_status_helper", "//tensorflow/cc/profiler", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core:lib", @@ -831,7 +1018,11 @@ filegroup( "c_api_unified_experimental_eager.cc", "c_api_unified_experimental_graph.cc", "c_api_unified_experimental_internal.h", + "gradient_checker.cc", + "gradient_checker.h", "gradients.cc", # Uses RTTI. + "gradients_util.cc", + "gradients_util.h", "*test*", "*dlpack*", ], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 76d603694e3..fb5ce22ae5f 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -51,9 +51,6 @@ limitations under the License. #include "tensorflow/core/protobuf/device_filters.pb.h" #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/util/device_name_utils.h" -#ifdef TENSORFLOW_EAGER_USE_XLA -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#endif // TENSORFLOW_EAGER_USE_XLA #include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -629,21 +626,30 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( "targets will fail."; } } else { - // The master's context_view_id will be incremented by one - // the UpdateRemoteMaster call later. We want all new workers and - // existing workers to also have the updated context_view_id, so - // we must set their context_view_id to the existing master's - // context_view_id + 1. - sg.Update(CreateRemoteContexts( - ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs, - server_def, remote_eager_workers.get(), context->Executor().Async(), - context->LazyCopyFunctionRemoteInputs(), base_request)); + if (sg.ok()) { + // Create remote contexts on the newly added workers only if the master + // has collected all device information from them (i.e., the + // GetAllRemoteDevices call returns succussfully). Note that in rare cases + // GetAllRemoteDevices can still fail even with RPCs configured to wait + // until the remote workers to become alive. If the master creates remote + // contexts on the workers whose devices are still not collected, those + // workers will be treated as existing workers subsequently, so the master + // will never get devices from them even with retrying UpdateServerDef. + sg.Update(CreateRemoteContexts( + ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs, + server_def, remote_eager_workers.get(), context->Executor().Async(), + context->LazyCopyFunctionRemoteInputs(), base_request)); + } if (!existing_workers.empty()) { if (VLOG_IS_ON(1)) { for (const string& w : existing_workers) { VLOG(1) << "Updating cluster with existing worker " << w; } } + // The master's context_view_id will be incremented by one in the + // UpdateRemoteMaster call later. We want existing workers to also have + // the updated context_view_id, so we must set their context_view_id to + // the master's current context_view_id + 1. 
sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers, removed_workers, context_id, context_view_id + 1, server_def, @@ -724,7 +730,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { if (opts->use_tfrt) { #ifdef PLATFORM_GOOGLE - return tensorflow::wrap(new tfrt::ContextInterface(opts->async)); + return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async)); #else status->status = tensorflow::errors::Unimplemented("TFRT is not supported"); return nullptr; @@ -745,7 +751,6 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { opts->session_options.options, static_cast( opts->device_placement_policy), - static_cast(opts->mirroring_policy), opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(), /*device_mgr_owned*/ true, r, tensorflow::GetDefaultCustomKernelCreator())); @@ -851,20 +856,9 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, #else // !defined(IS_MOBILE_PLATFORM) tensorflow::EagerContext* context = tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - tensorflow::GrpcServer* grpc_server = - static_cast(context->GetServer()); - - std::unique_ptr remote_eager_workers; - status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache( - &remote_eager_workers); - if (!status->status.ok()) { - LOG(ERROR) << "Failed to get client cache for remote workers."; - return false; - } - // TODO(yuefengz): support partially specified `worker_name`. tensorflow::core::RefCountPtr eager_client; - status->status = remote_eager_workers->GetClient(worker_name, &eager_client); + status->status = context->GetClient(worker_name, &eager_client); if (!status->status.ok()) { return false; } @@ -1149,26 +1143,23 @@ void TFE_DeleteOp(TFE_Op* op) { tensorflow::unwrap(op)->Release(); } +const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) { + return tensorflow::unwrap(op)->Name().c_str(); +} + +TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) { + return tensorflow::wrap( + &(OperationFromInterface(tensorflow::unwrap(op))->EagerContext())); +} + void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { status->status = tensorflow::unwrap(op)->SetDeviceName(device_name); } -const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { +const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) { return tensorflow::unwrap(op)->DeviceName().c_str(); } -void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { -#ifdef TENSORFLOW_EAGER_USE_XLA - tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable); - if (!s.ok()) { - LOG(ERROR) << "Could not enable XLA compilation for op: " << s; - } -#else - LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not " - "built with XLA support."; -#endif // TENSORFLOW_EAGER_USE_XLA -} - void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) { status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input)); } @@ -1181,6 +1172,15 @@ void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, static_cast(num_inputs)}); } +extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) { + return tensorflow::unwrap(op)->GetInputs().size(); +} + +extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index, + TF_Status* status) { + return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]); +} + TF_AttrType 
TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status) { TF_AttrType ret = TF_ATTR_INT; @@ -1486,7 +1486,7 @@ void TFE_ContextEndStep(TFE_Context* ctx) { tensorflow::unwrap(ctx)->EndStep(); } -const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) { +const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) { return tensorflow::wrap( &OperationFromInterface(tensorflow::unwrap(op))->Attrs()); } @@ -1551,8 +1551,67 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, TFE_OpSetAttrFunction(op, attr_name, func_op); TFE_DeleteOp(func_op); } break; - case tensorflow::AttrValue::kList: - TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::kList: { + // String + if (const int s_size = default_value.list().s_size()) { + absl::InlinedVector values_vector; + absl::InlinedVector lengths_vector; + for (int i = 0; i < s_size; ++i) { + const string& v = default_value.list().s(i); + values_vector.push_back(v.data()); + lengths_vector.push_back(v.size()); + } + TFE_OpSetAttrStringList(op, attr_name, values_vector.data(), + lengths_vector.data(), s_size); + } + + // Int + if (const int i_size = default_value.list().i_size()) { + absl::InlinedVector i_vector; + for (int i = 0; i < i_size; ++i) { + i_vector.push_back(default_value.list().i(i)); + } + TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size); + } + // Float + if (const int f_size = default_value.list().f_size()) { + absl::InlinedVector f_vector; + for (int i = 0; i < f_size; ++i) { + f_vector.push_back(default_value.list().f(i)); + } + TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size); + } + // Bool + if (const int b_size = default_value.list().b_size()) { + absl::InlinedVector b_vector; + for (int i = 0; i < b_size; i++) { + b_vector.push_back(default_value.list().b(i)); + } + TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size); + } + // Type + if (const int type_size = default_value.list().type_size()) { + absl::InlinedVector type_vector; + for (int i = 0; i < type_size; ++i) { + type_vector.push_back(default_value.list().type(i)); + } + TFE_OpSetAttrTypeList( + op, attr_name, + reinterpret_cast(type_vector.data()), + type_size); + } + + // Rest are not supported. 
+ if (default_value.list().shape_size() > 0 || + default_value.list().func_size() > 0 || + default_value.list().tensor_size() > 0) { + TF_SetStatus( + status, TF_UNIMPLEMENTED, + tensorflow::strings::StrCat("Unable to get setfor default value: ", + default_value.DebugString()) + .data()); + } + } break; case tensorflow::AttrValue::kTensor: TF_FALLTHROUGH_INTENDED; case tensorflow::AttrValue::kPlaceholder: @@ -1612,19 +1671,12 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { return status.status; } - tensorflow::Status Execute(tensorflow::EagerOperation* op, + tensorflow::Status Execute(const tensorflow::EagerOperation* op, tensorflow::TensorHandle** retvals, int* num_retvals) override { - std::vector inputs; - inputs.reserve(op->Inputs().size()); - for (int i = 0; i < op->Inputs().size(); ++i) { - op->Inputs()[i]->Ref(); - inputs.push_back(tensorflow::wrap(op->Inputs()[i])); - } std::vector outputs(*num_retvals); TF_Status status; - device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(), - wrap(&op->Attrs()), num_retvals, outputs.data(), &status, + device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status, info_); if (status.status.ok()) { for (int i = 0; i < *num_retvals; ++i) { @@ -1634,10 +1686,6 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { TFE_DeleteTensorHandle(outputs[i]); } } - - for (auto inp : inputs) { - TFE_DeleteTensorHandle(inp); - } return status.status; } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 5afe3047dd7..a58c681e8fe 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -248,22 +248,22 @@ typedef struct TFE_Op TFE_Op; TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status); - TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op); +// Returns the op or function name `op` will execute. +// +// The returned string remains valid throughout the lifetime of 'op'. +TF_CAPI_EXPORT extern const char* TFE_OpGetName(const TFE_Op* op, + TF_Status* status); +TF_CAPI_EXPORT extern TFE_Context* TFE_OpGetContext(const TFE_Op* op, + TF_Status* status); + TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status); // The returned string remains valid throughout the lifetime of 'op'. -TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op, +TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status); -// When 'enable' is set to 1, and if TensorFlow library is built with XLA -// support, a subsequent TFE_Execute() call on `op` will run the op via XLA. -// -// If the library is not built with XLA support, this call would be a no-op. -TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op, - unsigned char enable); - TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status); @@ -272,6 +272,23 @@ TF_CAPI_EXPORT extern void TFE_OpAddInputList(TFE_Op* op, int num_inputs, TF_Status* status); +// Fetches the current number of inputs attached to `op`. +// +// Does not use the operation's definition to determine how many inputs should +// be attached. It is intended for use with TFE_OpGetFlatInput to inspect an +// already-finalized operation. +// +// Note that TFE_OpGetFlatInputCount and TFE_OpGetFlatInput operate on a flat +// sequence of inputs, unlike TFE_OpGetInputLength (for getting the length of a +// particular named input list, which may only be part of the op's inputs). 
+TF_CAPI_EXPORT extern int TFE_OpGetFlatInputCount(const TFE_Op* op, + TF_Status* status); +// Returns a borrowed reference to one of `op`'s inputs. Use +// `TFE_TensorHandleCopySharingTensor` to make a new reference. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, + int index, + TF_Status* status); + TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index dd55f05283b..b5721cdab0a 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -22,9 +22,6 @@ limitations under the License. #include "tensorflow/c/tf_status_internal.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/platform/status.h" -#ifdef TENSORFLOW_EAGER_USE_XLA -#include "tensorflow/compiler/jit/xla_device.h" -#endif // TENSORFLOW_EAGER_USE_XLA using tensorflow::string; @@ -64,87 +61,6 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( return nullptr; } -#ifdef TENSORFLOW_EAGER_USE_XLA - auto* device = absl::get(handle->device()); - - // If tensor resides on an XLA device, use XLA device's PaddedShapeFn. - auto* xla_device = dynamic_cast(device); - if (xla_device != nullptr) { - tensorflow::XlaDevice::PaddedShapeFn shape_fn = - xla_device->metadata().padded_shape_fn(); - xla::Shape padded_shape; - status->status = shape_fn(*tensor, &padded_shape); - if (!status->status.ok()) { - return nullptr; - } - if (VLOG_IS_ON(3)) { - std::vector shape_to_log = - TensorShapeAsVector(*handle, &status->status); - if (!status->status.ok()) { - // Ignore the status here as we are simply logging. - status->status = tensorflow::Status::OK(); - } else { - VLOG(3) << "Fully padded shape of [" - << absl::StrJoin(shape_to_log, ", ") << "] is " - << padded_shape.DebugString(); - } - } - - if (padded_shape.IsTuple()) { - if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) { - // Currently, the only case of XlaTensor containing a tuple shape is to - // represent 64 bit ints, doubles, and complex numbers (we don't support - // 64bit complex numbers). - status->status = tensorflow::errors::InvalidArgument( - "XlaTensors should only contain tuples of size 2. Shape: ", - padded_shape.DebugString()); - return nullptr; - } - - // shape0 is not a const& because we will assign it to padded_shape below. - // It is illegal to assign a part of a message to itself. - xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0); - const xla::Shape& shape1 = - xla::ShapeUtil::GetTupleElementShape(padded_shape, 1); - if (shape0.IsTuple() || shape1.IsTuple()) { - status->status = tensorflow::errors::InvalidArgument( - "XlaTensors should not contain nested tuples. Shape: ", - padded_shape.DebugString()); - return nullptr; - } - if (!xla::ShapeUtil::Equal(shape0, shape1)) { - status->status = tensorflow::errors::InvalidArgument( - "Subshapes of XlaTensors should be the same. Shape: ", - padded_shape.DebugString()); - return nullptr; - } - - // Since the only case we handle here are two equal subshapes, we - // simply return one of them. The caller will interpret it as this - // shape directly storing the 64bit types. This approximation is good - // enough for this API's debugging use case. 
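The const accessors declared above (TFE_OpGetName, TFE_OpGetContext, TFE_OpGetDevice, TFE_OpGetFlatInputCount, TFE_OpGetFlatInput) make it possible to inspect a finalized op without mutating it; the CloneOp helper added to c_api_test.cc later in this change is the canonical use. A smaller sketch, in which the DumpOp name and the reduced error handling are illustrative and not part of this change:

#include <stdio.h>
#include "tensorflow/c/eager/c_api.h"

// Sketch: print an op's name, requested device, and flat inputs using the new
// const getters. Assumes `op` is a finalized TFE_Op* and `status` is owned by
// the caller.
void DumpOp(const TFE_Op* op, TF_Status* status) {
  const char* name = TFE_OpGetName(op, status);
  if (TF_GetCode(status) != TF_OK) return;
  const char* device = TFE_OpGetDevice(op, status);
  if (TF_GetCode(status) != TF_OK) return;
  int num_inputs = TFE_OpGetFlatInputCount(op, status);
  if (TF_GetCode(status) != TF_OK) return;
  printf("%s on '%s' with %d flat input(s)\n", name, device, num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    // Borrowed reference; copy with TFE_TensorHandleCopySharingTensor if it
    // must outlive `op`.
    TFE_TensorHandle* input = TFE_OpGetFlatInput(op, i, status);
    if (TF_GetCode(status) != TF_OK) return;
    printf("  input %d on %s\n", i, TFE_TensorHandleDeviceName(input, status));
  }
}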
- padded_shape = shape0; - } - - int rank = padded_shape.dimensions_size(); - std::vector dev_dims; - dev_dims.reserve(rank); - if (rank == 1) { - // Rank 1 tensors might not have padded_shape.layout.minor_to_major set, - dev_dims.push_back(padded_shape.dimensions(0)); - } else { - for (int i = rank - 1; i >= 0; --i) { - tensorflow::int64 dim_index = padded_shape.layout().minor_to_major(i); - dev_dims.push_back(padded_shape.dimensions(dim_index)); - } - } - status->status = tensorflow::Status::OK(); - return new TFE_TensorDebugInfo(dev_dims); - } -#endif // TENSORFLOW_EAGER_USE_XLA - - // If the tensor is not an XLA tensor, the device shape is - // the same as regular tensor shape. std::vector dev_dims = TensorShapeAsVector(*handle, &status->status); if (!status->status.ok()) { diff --git a/tensorflow/c/eager/c_api_distributed_test.cc b/tensorflow/c/eager/c_api_distributed_test.cc index 3738768cf02..2718c75c3ee 100644 --- a/tensorflow/c/eager/c_api_distributed_test.cc +++ b/tensorflow/c/eager/c_api_distributed_test.cc @@ -121,25 +121,6 @@ string AddVariablesFunction() { return def.SerializeAsString(); } -void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) { - TF_Status* status = TF_NewStatus(); - TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status); - EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - TFE_OpAddInput(op, var_handle, status); - TFE_TensorHandle* is_initialized[1] = {nullptr}; - int num_retvals = 1; - TFE_Execute(op, &is_initialized[0], &num_retvals, status); - CHECK_EQ(1, num_retvals); - TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status); - bool initialized = false; - memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t)); - EXPECT_EQ(initialized, true); - TF_DeleteTensor(t); - TFE_DeleteTensorHandle(is_initialized[0]); - TFE_DeleteOp(op); - delete status; -} - void TestFunctionWithPackedInput(const bool remote) { tensorflow::ServerDef server_def = GetServerDef(3); @@ -182,9 +163,8 @@ void TestFunctionWithPackedInput(const bool remote) { // Add a sync point in order to make sure that variables have been initialized // before the function execution starts. - // TODO(b/155789951): Remove once b/155789951 is fixed. - VarIsInitialized(ctx, h1); - VarIsInitialized(ctx, h2); + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); // Pack 3 variable handles into one TFE_TensorHandle. // When remote is false, function device is placed on task0. Handle types are @@ -396,6 +376,8 @@ TEST(CAPI, DistributedFunctionGraphPassOnlyOnce) { TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name); EXPECT_NE(var_handle, nullptr); + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); const string function_def = VariableAddFunction(); TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), @@ -517,8 +499,11 @@ void TestDistributedFunctionCancellation(bool inject_error) { TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name); EXPECT_NE(var_handle, nullptr); + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - const string function_def = VariableAddFunctionWithGraphError(); + const string function_def = inject_error ? 
VariableAddFunctionWithGraphError() + : VariableAddFunction(); TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), status); ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 7390cf243be..eabb159a631 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -486,29 +486,6 @@ TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2( static_cast(sampler->sampler->GetCell(label1, label2))); } -void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options, - TFE_ContextMirroringPolicy policy) { - options->mirroring_policy = policy; -} - -void TFE_ContextSetThreadLocalMirroringPolicy( - TFE_Context* ctx, TFE_ContextMirroringPolicy policy) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetThreadLocalMirroringPolicy( - static_cast(policy)); -} - -// Note: this function looks up a thread local policy. So it should be called in -// the appropriate client thread. In particular, in async mode, it may not be -// safe to call this function from the async EagerExecutor threads. -extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy( - TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - return static_cast(context->GetMirroringPolicy()); -} - void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options, bool lazy_copy) { options->lazy_remote_inputs_copy = lazy_copy; diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 1af76c01154..12546c6082a 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -265,33 +265,6 @@ TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2( TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2( TFE_MonitoringSampler2* sampler, const char* label1, const char* label2); -// LINT.IfChange -// Note: Keep in sync with internal copy of enum in eager/context.h. -typedef enum TFE_ContextMirroringPolicy { - // Do not maintain mirrors in a TensorHandle, instead make new TensorHandle - // copies with their own lifetime. - TFE_MIRRORING_NONE = 0, - // Mirroring any remote tensor handles, associating them with the lifetime of - // the local TensorHandle. - TFE_MIRRORING_ALL = 1, -} TFE_ContextMirroringPolicy; -// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h) - -TF_CAPI_EXPORT extern void TFE_ContextOptionsSetMirroringPolicy( - TFE_ContextOptions*, TFE_ContextMirroringPolicy); - -// Sets a thread-local mirroring policy. After this call, other calls to -// TFE_Execute in the same thread will use the mirroring policy specified here -// instead of the mirroring policy used to construct the context. This has no -// effect on the mirroring policy used by other program threads. -TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalMirroringPolicy( - TFE_Context*, TFE_ContextMirroringPolicy); - -// Returns the mirroring policy to be used by this context in the current -// thread. -TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy( - TFE_Context*); - // Sets whether to copy the remote inputs of a function lazily. 
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy( TFE_ContextOptions*, bool lazy_copy); @@ -441,7 +414,7 @@ typedef struct TFE_OpAttrs TFE_OpAttrs; // Fetch a reference to `op`'s attributes. The returned reference is only valid // while `op` is alive. -const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op); +TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op); // Add attributes in `attrs` to `op`. // // Does not overwrite or update existing attributes, but adds new ones. @@ -462,7 +435,11 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op, size_t proto_len, TF_Status* status); -#define TFE_CUSTOM_DEVICE_VERSION 2 +// TODO(b/166642410): It would be nice, for custom devices and for other users, +// to have a non-string representation of devices (TF_Device) extracted from +// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc. + +#define TFE_CUSTOM_DEVICE_VERSION 3 // Struct to be filled in typedef struct TFE_CustomDevice { @@ -481,9 +458,16 @@ typedef struct TFE_CustomDevice { void* device_info); // Method to execute an operation. - void (*execute)(TFE_Context* context, int num_inputs, - TFE_TensorHandle** inputs, const char* operation_name, - const TFE_OpAttrs* attributes, int* num_outputs, + // + // Arguments provide enough information to reconstruct the original `TFE_Op`, + // or construct a transformed version, by inspecting the passed `op`. + // + // TFE_OpGetDevice(op) records the original placement of the operation. It may + // be an empty string if no device was explicitly requested, but will + // otherwise be the name of this custom device. Ops are placed onto a custom + // device if any of their inputs are on that custom device, but custom devices + // are free to set a bad status in order to require explicit placement. + void (*execute)(const TFE_Op* op, int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s, void* device_info); // Method to delete a device. diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index a4d31417073..4975d303375 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -316,86 +316,6 @@ TEST(CAPI, Function_ident_CPU) { TF_DeleteStatus(status); } -#ifdef TENSORFLOW_EAGER_USE_XLA -TEST(CAPI, Function_ident_XLA_CPU) { - // First create a simple identity function. 
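The version-3 `execute` comment above implies that a custom device rebuilds or transforms the op it is handed using the same introspection APIs added in this change. A minimal pass-through callback under that contract might look like the sketch below; the PassThroughExecute name, the hard-coded CPU target, and the ignored device_info payload are illustrative assumptions, and a real device would also translate any tensor handles it owns rather than forwarding them directly.

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Sketch of a TFE_CUSTOM_DEVICE_VERSION 3 `execute` callback that re-creates
// the incoming op on a physical device and runs it there.
void PassThroughExecute(const TFE_Op* original_op, int* num_outputs,
                        TFE_TensorHandle** outputs, TF_Status* s,
                        void* device_info) {
  (void)device_info;  // Unused in this sketch.
  TFE_Context* ctx = TFE_OpGetContext(original_op, s);
  if (TF_GetCode(s) != TF_OK) return;
  const char* op_name = TFE_OpGetName(original_op, s);
  if (TF_GetCode(s) != TF_OK) return;
  TFE_Op* op = TFE_NewOp(ctx, op_name, s);
  if (TF_GetCode(s) != TF_OK) return;
  TFE_OpAddAttrs(op, TFE_OpGetAttrs(original_op));
  // Re-place the op; a real device would choose this from its configuration.
  TFE_OpSetDevice(op, "/job:localhost/replica:0/task:0/device:CPU:0", s);
  int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
  for (int i = 0; i < num_inputs && TF_GetCode(s) == TF_OK; ++i) {
    TFE_OpAddInput(op, TFE_OpGetFlatInput(original_op, i, s), s);
  }
  if (TF_GetCode(s) == TF_OK) {
    TFE_Execute(op, outputs, num_outputs, s);
  }
  TFE_DeleteOp(op);
}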
- TF_Graph* function_graph = TF_NewGraph(); - TF_OperationDescription* arg_descr = - TF_NewOperation(function_graph, "Placeholder", "arg"); - TF_SetAttrType(arg_descr, "dtype", TF_INT32); - TF_Status* status = TF_NewStatus(); - TF_Operation* arg = TF_FinishOperation(arg_descr, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_OperationDescription* id_descr = - TF_NewOperation(function_graph, "Identity", "id"); - TF_SetAttrType(id_descr, "T", TF_INT32); - TF_AddInput(id_descr, {arg, 0}); - TF_Operation* id = TF_FinishOperation(id_descr, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_Output input{arg, 0}; - TF_Output output{id, 0}; - TF_Function* fn = - TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1, - &output, nullptr, nullptr, "test", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteGraph(function_graph); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_Context* ctx = TFE_NewContext(opts, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_DeleteContextOptions(opts); - TFE_ContextAddFunction(ctx, fn, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteFunction(fn); - - for (bool async : {false, true, false}) { - TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx); - TFE_Executor* executor = TFE_NewExecutor(async); - TFE_ContextSetExecutorForThread(ctx, executor); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK); - TF_Tensor* t = - TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); - *reinterpret_cast(TF_TensorData(t)) = 42; - TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteTensor(t); - - TFE_Op* op = TFE_NewOp(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_OpAddInput(op, h, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - - // Now run it via XLA. 
- TFE_OpSetXLACompilation(op, true); - - std::vector result; - result.push_back(nullptr); - int num_retvals = 1; - TFE_Execute(op, result.data(), &num_retvals, status); - TFE_DeleteOp(op); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - ASSERT_EQ(num_retvals, 1); - - TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); - TFE_ContextSetExecutorForThread(ctx, old_executor); - TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteExecutor(executor); - TFE_DeleteExecutor(old_executor); - TFE_DeleteTensorHandle(h); - TF_DeleteTensor(r); - TFE_DeleteTensorHandle(result[0]); - } - TFE_ContextRemoveFunction(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_DeleteContext(ctx); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteStatus(status); -} -#endif // TENSORFLOW_EAGER_USE_XLA - void Executor_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 4d9be0c2501..356476c2186 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -32,7 +32,6 @@ struct TFE_ContextOptions { bool async = false; TFE_ContextDevicePlacementPolicy device_placement_policy{ TFE_DEVICE_PLACEMENT_SILENT}; - TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE}; // If true, lazily copy the remote inputs of a function to the target devices. bool lazy_remote_inputs_copy = true; // If true, use TFRT backend diff --git a/tensorflow/c/eager/c_api_remote_function_test.cc b/tensorflow/c/eager/c_api_remote_function_test.cc new file mode 100644 index 00000000000..a9bbd5b694f --- /dev/null +++ b/tensorflow/c/eager/c_api_remote_function_test.cc @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/eager/c_api_remote_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace { + +void TestRemoteExecuteSilentCopiesFunc(bool async, bool remote, + bool heavy_load_on_streaming_rpc, + bool remote_func_outputs = false) { + return TestRemoteExecuteSilentCopies(async, remote, /*func=*/true, + heavy_load_on_streaming_rpc, + remote_func_outputs); +} + +TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) { + TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true, + /*heavy_load_on_streaming_rpc=*/false); +} +TEST(CAPI, RemoteExecuteSilentCopiesFuncRemoteOutputs) { + TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/true, + /*heavy_load_on_streaming_rpc=*/false, + /*remote_func_outputs=*/true); +} +TEST(CAPI, RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) { + TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true, + /*heavy_load_on_streaming_rpc=*/false, + /*remote_func_outputs=*/true); +} +TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) { + TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false, + /*heavy_load_on_streaming_rpc=*/false); +} +TEST(CAPI, RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) { + TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/false, + /*heavy_load_on_streaming_rpc=*/false, + /*remote_func_outputs=*/true); +} +TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) { + TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false, + /*heavy_load_on_streaming_rpc=*/false, + /*remote_func_outputs=*/true); +} +TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) { + // A remote input may be not ready when we start running a function. Test that + // the function execution should wait until the remote input is ready. + TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false, + /*heavy_load_on_streaming_rpc=*/true); +} + +} // namespace diff --git a/tensorflow/c/eager/c_api_remote_test.cc b/tensorflow/c/eager/c_api_remote_test.cc index 94c32cf3f30..e68e15ba560 100644 --- a/tensorflow/c/eager/c_api_remote_test.cc +++ b/tensorflow/c/eager/c_api_remote_test.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "absl/strings/str_cat.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/c_api_remote_test_util.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" @@ -115,225 +117,24 @@ void TestRemoteExecute(bool async) { TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } -string MatMulFunction() { - tensorflow::FunctionDef def; - CHECK(tensorflow::protobuf::TextFormat::ParseFromString( - " signature {" - " name: 'MatMulFunction'" - " input_arg {" - " name: 'a'" - " type: DT_FLOAT" - " }" - " input_arg {" - " name: 'b'" - " type: DT_FLOAT" - " }" - " output_arg {" - " name: 'm'" - " type: DT_FLOAT" - " }" - " }" - " node_def {" - " name: 'matmul'" - " op: 'MatMul'" - " input: 'a'" - " input: 'b'" - " attr {" - " key: 'T'" - " value {" - " type: DT_FLOAT" - " }" - " }" - " }" - " ret {" - " key: 'm'" - " value: 'matmul:product'" - " }", - &def)); - return def.SerializeAsString(); -} - -// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one -// which creates a remote remote input, to simulate a scenario that the remote -// input is not ready when we start running an op or a function. -void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func, - bool heavy_load_on_streaming_rpc) { - tensorflow::ServerDef server_def = GetServerDef(3); - - // This server def has the task index set to 0. - string serialized = server_def.SerializeAsString(); - - server_def.set_task_index(1); - std::unique_ptr worker_server1; - ASSERT_TRUE(tensorflow::GrpcServer::Create( - server_def, tensorflow::Env::Default(), &worker_server1) - .ok()); - ASSERT_TRUE(worker_server1->Start().ok()); - - server_def.set_task_index(2); - std::unique_ptr worker_server2; - ASSERT_TRUE(tensorflow::GrpcServer::Create( - server_def, tensorflow::Env::Default(), &worker_server2) - .ok()); - ASSERT_TRUE(worker_server2->Start().ok()); - - TF_Status* status = TF_NewStatus(); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetAsync(opts, static_cast(async)); - TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); - TFE_Context* ctx = TFE_NewContext(opts, status); - EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - TFE_DeleteContextOptions(opts); - - TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); - EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - - TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx); - TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx); - std::vector handles_task0; - if (heavy_load_on_streaming_rpc) { - // Send 50 tensor copy requests to simulate that there have been some RPC - // requests been enqueued. 
- for (int i = 0; i < 50; ++i) { - handles_task0.push_back(TestMatrixTensorHandle(ctx)); - } - } - const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; - const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; - - std::vector handles_task2; - for (auto* h_task0 : handles_task0) { - handles_task2.push_back( - TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status)); - ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - } - - auto* h1_task2 = - TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status); - ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - - TFE_Op* matmul = nullptr; - if (func) { - string function_def = MatMulFunction(); - TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), - status); - CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - - matmul = TFE_NewOp(ctx, "MatMulFunction", status); - ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - TFE_OpAddInput(matmul, h0_task0, status); - ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - TFE_OpAddInput(matmul, h1_task2, status); - ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - } else { - // Handles are on task0 (local), and task2, but op is on task1. - matmul = MatMulOp(ctx, h0_task0, h1_task2); - } - if (remote) { - TFE_OpSetDevice(matmul, task1_name, status); - ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - } else if (!async) { - // Set the local device to CPU to easily validate mirroring - string cpu_device_name; - ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU")); - TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status); - EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - auto remote_arg = - tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2)); - // The input handles should never change since they have been mirrored. - ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr)); - } - - TFE_TensorHandle* retvals[1]; - int num_retvals = 1; - TFE_Execute(matmul, &retvals[0], &num_retvals, status); - EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - - // TODO(gjn): Add support for waiting on async local mirrors - if (!remote && !async) { - auto remote_arg = - tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2)); - // The input handles should never change since they have been mirrored. 
- ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr)); - } - - auto* retval_task0 = TFE_TensorHandleCopyToDevice( - retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); - ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - - TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); - ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - TFE_DeleteTensorHandle(retval_task0); - float product[4] = {0}; - EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); - memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); - TF_DeleteTensor(t); - EXPECT_EQ(7, product[0]); - EXPECT_EQ(10, product[1]); - EXPECT_EQ(15, product[2]); - EXPECT_EQ(22, product[3]); - - TFE_DeleteTensorHandle(h0_task0); - TFE_DeleteTensorHandle(h1_task0); - TFE_DeleteTensorHandle(h1_task2); - TFE_DeleteTensorHandle(retvals[0]); - for (auto* h : handles_task0) { - TFE_DeleteTensorHandle(h); - } - for (auto* h : handles_task2) { - TFE_DeleteTensorHandle(h); - } - - TFE_DeleteOp(matmul); - - TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); - TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - TFE_DeleteExecutor(executor); - if (func) { - TFE_ContextRemoveFunction(ctx, "MatMulFunction", status); - } - TFE_DeleteContext(ctx); - - TF_DeleteStatus(status); - - // TODO(b/136478427): Figure out how to correctly shut the server down. - worker_server1.release(); - worker_server2.release(); +void TestRemoteExecuteSilentCopiesOp(bool async, bool remote, + bool remote_func_outputs = false) { + return TestRemoteExecuteSilentCopies(async, remote, /*func=*/false, + /*heavy_load_on_streaming_rpc=*/false, + remote_func_outputs); } TEST(CAPI, RemoteExecuteSilentCopies) { - TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/true, - /*func=*/false, - /*heavy_load_on_streaming_rpc=*/false); + TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/true); } TEST(CAPI, RemoteExecuteSilentCopiesAsync) { - TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/false, - /*heavy_load_on_streaming_rpc=*/false); -} -TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) { - TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/true, - /*heavy_load_on_streaming_rpc=*/false); + TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/true); } TEST(CAPI, RemoteExecuteSilentCopiesLocal) { - TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false, - /*func=*/false, - /*heavy_load_on_streaming_rpc=*/false); + TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/false); } TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) { - TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, - /*func=*/false, - /*heavy_load_on_streaming_rpc=*/false); -} -TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) { - TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true, - /*heavy_load_on_streaming_rpc=*/false); -} -TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) { - // A remote input may be not ready when we start running a function. Test that - // the function execution should wait until the remote input is ready. 
- TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true, - /*heavy_load_on_streaming_rpc=*/true); + TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/false); } } // namespace diff --git a/tensorflow/c/eager/c_api_remote_test_util.cc b/tensorflow/c/eager/c_api_remote_test_util.cc new file mode 100644 index 00000000000..159fa442a73 --- /dev/null +++ b/tensorflow/c/eager/c_api_remote_test_util.cc @@ -0,0 +1,222 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_remote_test_util.h" + +#include "absl/strings/str_cat.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" + +using ::tensorflow::string; + +string MatMulFunction(const string& matmul_device) { + tensorflow::FunctionDef def; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + absl::StrCat(" signature {" + " name: 'MatMulFunction'" + " input_arg {" + " name: 'a'" + " type: DT_FLOAT" + " }" + " input_arg {" + " name: 'b'" + " type: DT_FLOAT" + " }" + " output_arg {" + " name: 'm'" + " type: DT_FLOAT" + " }" + " }" + " node_def {" + " name: 'matmul'" + " op: 'MatMul'" + " input: 'a'" + " input: 'b'" + " device: '", + matmul_device, "'", + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " ret {" + " key: 'm'" + " value: 'matmul:product'" + " }"), + &def)); + return def.SerializeAsString(); +} + +void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func, + bool heavy_load_on_streaming_rpc, + bool remote_func_outputs) { + tensorflow::ServerDef server_def = GetServerDef(3); + + // This server def has the task index set to 0. 
+ string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + + server_def.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx); + std::vector handles_task0; + if (heavy_load_on_streaming_rpc) { + // Send 50 tensor copy requests to simulate that there have been some RPC + // requests been enqueued. + for (int i = 0; i < 50; ++i) { + handles_task0.push_back(TestMatrixTensorHandle(ctx)); + } + } + const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; + const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; + + std::vector handles_task2; + for (auto* h_task0 : handles_task0) { + handles_task2.push_back( + TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status)); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + } + + auto* h1_task2 = + TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TFE_Op* matmul = nullptr; + if (func) { + const string matmul_device = remote_func_outputs ? task2_name : ""; + string function_def = MatMulFunction(matmul_device); + TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), + status); + CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + matmul = TFE_NewOp(ctx, "MatMulFunction", status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_OpAddInput(matmul, h0_task0, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_OpAddInput(matmul, h1_task2, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + } else { + // Handles are on task0 (local), and task2, but op is on task1. + matmul = MatMulOp(ctx, h0_task0, h1_task2); + } + if (remote) { + TFE_OpSetDevice(matmul, task1_name, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + } else if (!async) { + // Set the local device to CPU to easily validate mirroring + string cpu_device_name; + ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU")); + TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + auto remote_arg = + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2)); + // The input handles should never change since they have been mirrored. 
+ ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr)); + } + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // TODO(gjn): Add support for waiting on async local mirrors + if (!remote && !async && !remote_func_outputs) { + auto remote_arg = + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2)); + // The input handles should never change since they have been mirrored. + ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr)); + } + + if (remote_func_outputs) { + const string backing_device = + TFE_TensorHandleBackingDeviceName(retvals[0], status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + EXPECT_EQ(backing_device, task2_name); + } + + auto* retval_task0 = TFE_TensorHandleCopyToDevice( + retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_DeleteTensorHandle(retval_task0); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(h1_task2); + TFE_DeleteTensorHandle(retvals[0]); + for (auto* h : handles_task0) { + TFE_DeleteTensorHandle(h); + } + for (auto* h : handles_task2) { + TFE_DeleteTensorHandle(h); + } + + TFE_DeleteOp(matmul); + + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_DeleteExecutor(executor); + if (func) { + TFE_ContextRemoveFunction(ctx, "MatMulFunction", status); + } + TFE_DeleteContext(ctx); + + TF_DeleteStatus(status); + + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server1.release(); + worker_server2.release(); +} diff --git a/tensorflow/c/eager/c_api_remote_test_util.h b/tensorflow/c/eager/c_api_remote_test_util.h new file mode 100644 index 00000000000..08633689402 --- /dev/null +++ b/tensorflow/c/eager/c_api_remote_test_util.h @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_ +#define TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_ + +// Run a function containing a MatMul op and check its output. +// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one +// which creates a remote remote input, to simulate a scenario that the remote +// input is not ready when we start running an op or a function. 
+void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func, + bool heavy_load_on_streaming_rpc, + bool remote_func_outputs = false); + +#endif // TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 724176505ba..fd208c6770d 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include // clang-format off +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/platform/platform.h" // clang-format on @@ -876,89 +877,6 @@ TEST(CAPI, Execute_Min_CPU) { TF_DeleteStatus(status); } -#ifdef TENSORFLOW_EAGER_USE_XLA -void Execute_MatMul_XLA_CPU(bool async) { - TF_Status* status = TF_NewStatus(); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetAsync(opts, static_cast(async)); - TFE_Context* ctx = TFE_NewContext(opts, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContextOptions(opts); - - TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); - TFE_Op* matmul = MatMulOp(ctx, m, m); - - TFE_OpSetXLACompilation(matmul, true); - - TFE_TensorHandle* retvals[1] = {nullptr}; - int num_retvals = 1; - TFE_Execute(matmul, &retvals[0], &num_retvals, status); - // Running a primitive TF operator via XLA is not yet supported. - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - TFE_DeleteOp(matmul); - TFE_DeleteTensorHandle(m); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - EXPECT_EQ(1, num_retvals); - - TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); - TFE_DeleteTensorHandle(retvals[0]); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - float product[4] = {0}; - EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); - memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); - TF_DeleteTensor(t); - EXPECT_EQ(7, product[0]); - EXPECT_EQ(10, product[1]); - EXPECT_EQ(15, product[2]); - EXPECT_EQ(22, product[3]); - TFE_DeleteContext(ctx); - TF_DeleteStatus(status); -} -TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); } -TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); } - -void Execute_Min_XLA_CPU(bool async) { - TF_Status* status = TF_NewStatus(); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetAsync(opts, static_cast(async)); - TFE_Context* ctx = TFE_NewContext(opts, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContextOptions(opts); - - TFE_TensorHandle* input = TestMatrixTensorHandle(ctx); - TFE_TensorHandle* axis = TestAxisTensorHandle(ctx); - TFE_Op* minOp = MinOp(ctx, input, axis); - - TFE_OpSetXLACompilation(minOp, true); - - TFE_TensorHandle* retvals[1] = {nullptr}; - int num_retvals = 1; - TFE_Execute(minOp, &retvals[0], &num_retvals, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteOp(minOp); - TFE_DeleteTensorHandle(input); - TFE_DeleteTensorHandle(axis); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - ASSERT_EQ(1, num_retvals); - - TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); - TFE_DeleteTensorHandle(retvals[0]); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - float output[2] = {0}; - EXPECT_EQ(sizeof(output), TF_TensorByteSize(t)); - memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t)); - TF_DeleteTensor(t); - EXPECT_EQ(1, output[0]); - EXPECT_EQ(3, output[1]); - TFE_DeleteContext(ctx); - 
TF_DeleteStatus(status); -} -TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); } -TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); } -#endif // TENSORFLOW_EAGER_USE_XLA - void ExecuteWithTracing(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -1274,6 +1192,68 @@ TEST(CAPI, StringAttributes) { TF_DeleteStatus(status); } +// Same test as above, expect use SetOpAttrValueScalar to set attrs. +TEST(CAPI, TestTFE_SetOpAttrs) { + // Test that TFE_OpSetAttrString doesn't hold on to the value after it + // returns. + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::vector dims(4, 1); + TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* tensor = + TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float)); + float tensor_data[] = {1}; + memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor)); + TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, tensor_handle, status); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(tensor_handle); + + tensorflow::AttrValue i_list_values; + for (int i = 0; i < 4; ++i) { + i_list_values.mutable_list()->add_i(1); + } + SetOpAttrValueScalar(ctx, op, i_list_values, "ksize", status); + SetOpAttrValueScalar(ctx, op, i_list_values, "strides", status); + + tensorflow::AttrValue padding_value; + *padding_value.mutable_s() = "VALID"; + tensorflow::SetOpAttrValueScalar(ctx, op, padding_value, "padding", status); + + tensorflow::AttrValue data_format_value; + *data_format_value.mutable_s() = "NHWC"; + tensorflow::SetOpAttrValueScalar(ctx, op, data_format_value, "data_format", + status); + + TFE_OpSetAttrType(op, "T", TF_FLOAT); + + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(op, &retvals[0], &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + tensor = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(4, TF_TensorByteSize(tensor)); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(op); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); +} + TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -1620,4 +1600,91 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) { TFE_DeleteContext(ctx); } +// Needs to work with a const TFE_Op since custom devices should not modify the +// op they are called with. 
+TFE_Op* CloneOp(const TFE_Op* other) { + TF_Status* status = TF_NewStatus(); + TFE_Context* context = TFE_OpGetContext(other, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + const char* op_name = TFE_OpGetName(other, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Op* ret = TFE_NewOp(context, op_name, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + const char* device = TFE_OpGetDevice(other, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetDevice(ret, device, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddAttrs(ret, TFE_OpGetAttrs(other)); + int num_inputs = TFE_OpGetFlatInputCount(other, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + for (int input_index = 0; input_index < num_inputs; ++input_index) { + TFE_TensorHandle* input = TFE_OpGetFlatInput(other, input_index, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(ret, input, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + } + TF_DeleteStatus(status); + return ret; +} + +TEST(CAPI, TestTFE_OpRecreation) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + // Clone an op with attributes and a device set. + TFE_Op* original_var_op = TFE_NewOp(ctx, "VarHandleOp", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrType(original_var_op, "dtype", TF_INT64); + TFE_OpSetAttrShape(original_var_op, "shape", {}, 0, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ("", std::string(TFE_OpGetDevice(original_var_op, status))); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetDevice(original_var_op, + "/job:localhost/replica:0/task:0/device:CPU:0", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Op* cloned = CloneOp(original_var_op); + + EXPECT_EQ("/job:localhost/replica:0/task:0/device:CPU:0", + std::string(TFE_OpGetDevice(cloned, status))); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ("VarHandleOp", std::string(TFE_OpGetName(cloned, status))); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + int num_retvals = 1; + TFE_TensorHandle* ret; + TFE_Execute(cloned, &ret, &num_retvals, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(ret); + + // Clone an op with inputs and no device set. 
+ TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx); + TFE_Op* original_identity = TFE_NewOp(ctx, "IdentityN", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* inputs[] = {input1, input2}; + TFE_OpAddInputList(original_identity, inputs, 2, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Op* cloned_identity = CloneOp(original_identity); + EXPECT_EQ("", std::string(TFE_OpGetDevice(cloned_identity, status))); + TFE_TensorHandle* identity_ret[] = {nullptr, nullptr}; + num_retvals = 2; + TFE_Execute(cloned_identity, identity_ret, &num_retvals, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteTensorHandle(input1); + TFE_DeleteTensorHandle(input2); + TFE_DeleteTensorHandle(identity_ret[0]); + TFE_DeleteTensorHandle(identity_ret[1]); + + TFE_DeleteOp(cloned_identity); + TFE_DeleteOp(original_identity); + TFE_DeleteOp(original_var_op); + TFE_DeleteOp(cloned); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); +} + } // namespace diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index 192f10533a6..fd68866f502 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -102,6 +102,32 @@ TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx, return th; } +TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[], + int64_t dims[], int num_dims) { + TF_Status* status = TF_NewStatus(); + TF_Tensor* t = + TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[], + int64_t dims[], int num_dims) { + TF_Status* status = TF_NewStatus(); + TF_Tensor* t = + TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) { constexpr int64_t dims[] = {100, 100}; constexpr int num_elements = dims[0] * dims[1]; diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index fcf407aa9c3..2f77ae5cf44 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -40,6 +40,14 @@ TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx, float data[], int64_t dims[], int num_dims); +// Get a Matrix TensorHandle with given float values and dimensions +TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[], + int64_t dims[], int num_dims); + +// Get a Matrix TensorHandle with given int values and dimensions +TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[], + int64_t dims[], int num_dims); + // Return a tensor handle containing a 100x100 matrix of floats TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx); diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index 8408f7ef60f..2d290df19ce 
100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -39,7 +39,7 @@ static FactoriesMap& GetFactories() { return *factories; } -static const char* default_factory = ""; +static tracing::FactoryFunction default_factory; void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) { assert((!GetFactories().count(name)) || @@ -48,15 +48,15 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) { GetFactories()[name] = factory; } -void SetDefaultTracingEngine(const char* name) { default_factory = name; } - -static TracingContext* CreateTracingExecutionContext(const char* fn_name, - TF_Status* s) { - auto entry = GetFactories().find(default_factory); - if (entry != GetFactories().end()) return entry->second(fn_name, s); +Status SetDefaultTracingEngine(const char* name) { + auto entry = GetFactories().find(name); + if (entry != GetFactories().end()) { + default_factory = GetFactories().find(name)->second; + return Status::OK(); + } string msg = absl::StrCat( - "No tracing engine factory has been registered with the key '", - default_factory, "' (available: "); + "No tracing engine factory has been registered with the key '", name, + "' (available: "); // Ensure deterministic (sorted) order in the error message std::set factories_sorted; for (const auto& factory : GetFactories()) @@ -68,7 +68,16 @@ static TracingContext* CreateTracingExecutionContext(const char* fn_name, } msg += ")"; - TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); + return errors::InvalidArgument(msg.c_str()); +} + +static TracingContext* CreateTracingExecutionContext(const char* fn_name, + TF_Status* s) { + if (default_factory) { + return default_factory(fn_name, s); + } + Set_TF_Status_from_Status( + s, errors::FailedPrecondition("default_factory is nullptr")); return nullptr; } @@ -99,8 +108,8 @@ using tensorflow::tracing::TracingContext; using tensorflow::tracing::TracingOperation; using tensorflow::tracing::TracingTensorHandle; -void TF_SetTracingImplementation(const char* name) { - SetDefaultTracingEngine(name); +void TF_SetTracingImplementation(const char* name, TF_Status* s) { + Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name)); } // Creates a new TensorFlow function, it is an execution context attached to a diff --git a/tensorflow/c/eager/c_api_unified_experimental.h b/tensorflow/c/eager/c_api_unified_experimental.h index b66869b4290..d216b4e694b 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.h +++ b/tensorflow/c/eager/c_api_unified_experimental.h @@ -52,7 +52,7 @@ typedef struct TF_AbstractFunction TF_AbstractFunction; // This allows the client to swap the implementation of the tracing engine. // Any future call to TF_CreateFunction will use the implementation defined // here. -void TF_SetTracingImplementation(const char* name); +void TF_SetTracingImplementation(const char* name, TF_Status*); // Creates a new TensorFlow function. A Function is an execution context, and as // such it can trace operations through TF_ExecuteOperation. 
diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 7bda3aed76d..0e9d6c18157 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -85,7 +85,11 @@ class GraphOperation : public TracingOperation { return errors::FailedPrecondition( "GraphOperation::Reset must be called before calling SetOpName."); } - op_.reset(TF_NewOperation(g_, op_type_.c_str(), op_name)); + // TODO(b/145674566): We use Graph::NewName to get a unique name here but + // this may not be consistent with python's naming policy. + mutex_lock l(g_->mu); + op_.reset(new TF_OperationDescription(g_, op_type_.c_str(), + g_->graph.NewName(op_name).c_str())); return Status::OK(); } const string& Name() const override { return op_type_; } @@ -361,9 +365,10 @@ class GraphContext : public TracingContext { } auto s = TF_NewStatus(); - func->func = TF_GraphToFunction( - graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(), - graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s); + func->func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr, + inputs_.size(), inputs_.data(), + graph_outputs.size(), graph_outputs.data(), + nullptr, nullptr, name_.data(), s); TF_RETURN_IF_ERROR(StatusFromTF_Status(s)); TF_DeleteStatus(s); *f = func.release(); @@ -387,7 +392,7 @@ class GraphContext : public TracingContext { private: std::unique_ptr graph_; std::vector inputs_; - const char* name_; + string name_; }; static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) { @@ -397,7 +402,7 @@ static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) { // Register the tracing implemented in this file as the default tracing engine. static bool register_tracing = [] { RegisterTracingEngineFactory("graphdef", GraphTracingFactory); - SetDefaultTracingEngine("graphdef"); + SetDefaultTracingEngine("graphdef").IgnoreError(); return true; }(); diff --git a/tensorflow/c/eager/c_api_unified_experimental_internal.h b/tensorflow/c/eager/c_api_unified_experimental_internal.h index c00e04d98af..9433fe8f120 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_internal.h +++ b/tensorflow/c/eager/c_api_unified_experimental_internal.h @@ -120,7 +120,7 @@ class TracingContext : public AbstractContext { }; typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*); -void SetDefaultTracingEngine(const char* name); +Status SetDefaultTracingEngine(const char* name); void RegisterTracingEngineFactory(const ::tensorflow::string& name, FactoryFunction factory); } // namespace tracing diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index c56e8ab05fc..432ddb4b2d4 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -22,19 +22,30 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" +using tensorflow::Status; using tensorflow::string; +using tensorflow::TF_StatusPtr; namespace tensorflow { namespace { +// The tests are parameterized on: +// - a string representing the tracing implementation: "mlir" or "graphdef". +// - a boolean that when true enables TFRT as the execution engine. class UnifiedCAPI : public ::testing::TestWithParam> { protected: void SetUp() override { - TF_SetTracingImplementation(std::get<0>(GetParam())); + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); } }; @@ -554,7 +565,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) { auto* add_op = TF_NewAbstractOp(graph_ctx); TF_AbstractOpSetOpType(add_op, "Add", s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_AbstractOpSetOpName(add_op, "my_add1", s); + TF_AbstractOpSetOpName(add_op, "my_add", s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_AbstractTensor* inputs[2] = {arg0, arg1}; TF_OutputList* add_outputs = TF_NewOutputList(); @@ -576,7 +587,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) { auto* add_op = TF_NewAbstractOp(graph_ctx); TF_AbstractOpSetOpType(add_op, "Add", s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_AbstractOpSetOpName(add_op, "my_add2", s); + TF_AbstractOpSetOpName(add_op, "my_add", s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_AbstractTensor* inputs[2] = {arg1, arg1}; TF_OutputList* add_outputs = TF_NewOutputList(); @@ -983,6 +994,10 @@ TEST_P(UnifiedCAPI, TF_ExecutionContextGetTFEContextFromFunctionContextRaises) { TF_DeleteExecutionContext(graph_ctx); } + +// The above tests are run for a combination of: +// - graphdef and MLIR tracing engine +// - Using TFRT as an execution runtime (true == enable TFRT) #ifdef PLATFORM_GOOGLE INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Combine(::testing::Values("graphdef", diff --git a/tensorflow/c/eager/custom_device_test.cc b/tensorflow/c/eager/custom_device_test.cc index 1c078d4f42c..b058c79a17b 100644 --- a/tensorflow/c/eager/custom_device_test.cc +++ b/tensorflow/c/eager/custom_device_test.cc @@ -36,7 +36,8 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) { bool arrived = false; bool executed = false; const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context, name, &arrived, &executed, status.get()); + RegisterLoggingDevice(context, name, /*strict_scope_placement=*/true, + &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context); ASSERT_FALSE(arrived); @@ -73,7 +74,8 @@ TEST(CUSTOM_DEVICE, ResetOperation) { bool executed = false; const char* custom_device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed, + RegisterLoggingDevice(context.get(), custom_device_name, + /*strict_scope_placement=*/true, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -103,7 +105,8 @@ TEST(CUSTOM_DEVICE, MakeVariable) { bool 
arrived = false; bool executed = false; const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); + RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true, + &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // Create a variable handle placed on the custom device. @@ -187,7 +190,8 @@ TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) { bool arrived = false; bool executed = false; const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); + RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/false, + &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // Create a variable handle placed on the custom device. @@ -264,10 +268,12 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) { const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1"; bool arrived = false; bool executed = false; - RegisterLoggingDevice(context.get(), custom0, &arrived, &executed, + RegisterLoggingDevice(context.get(), custom0, + /*strict_scope_placement=*/false, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - RegisterLoggingDevice(context.get(), custom1, &arrived, &executed, + RegisterLoggingDevice(context.get(), custom1, + /*strict_scope_placement=*/true, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -314,14 +320,34 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) { ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0)); ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1)); - // Custom device: mix of custom/physical fails. + // Custom device: mix of custom/physical places the op on the custom device. 
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get())); num_retvals = 1; + executed = false; TFE_Execute(matmul.get(), &retval, &num_retvals, status.get()); - ASSERT_NE(TF_OK, TF_GetCode(status.get())); - ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0)); - ASSERT_TRUE( - absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull + EXPECT_TRUE(executed); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_DeleteTensorHandle(retval); + + // Explicit placement still forces the op onto the requested device + matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get())); + TFE_OpSetDevice(matmul.get(), "/job:localhost/replica:0/task:0/device:CPU:0", + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + num_retvals = 1; + executed = false; + TFE_Execute(matmul.get(), &retval, &num_retvals, status.get()); + EXPECT_FALSE(executed); + ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK); + + // Custom devices can refuse to do type-based dispatch (as hcustom1 is + // configured to do) + matmul.reset(MatMulOp(context.get(), hcustom1.get(), hcpu.get())); + num_retvals = 1; + executed = false; + TFE_Execute(matmul.get(), &retval, &num_retvals, status.get()); + EXPECT_FALSE(executed); + ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK); } TEST(CUSTOM_DEVICE, InvalidRegistrationError) { @@ -334,21 +360,24 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); bool arrived = false; bool executed = false; - RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed, + RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", + /*strict_scope_placement=*/true, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT) << TF_Message(status.get()); const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); + RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true, + &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS) - << TF_Message(status.get()); - - RegisterLoggingDevice(context.get(), - "/job:localhost/replica:0/task:0/device:CPU:0", + RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS) << TF_Message(status.get()); + + RegisterLoggingDevice( + context.get(), "/job:localhost/replica:0/task:0/device:CPU:0", + /*strict_scope_placement=*/true, &arrived, &executed, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS) + << TF_Message(status.get()); } diff --git a/tensorflow/c/eager/custom_device_testutil.cc b/tensorflow/c/eager/custom_device_testutil.cc index 28de3665653..014abe38368 100644 --- a/tensorflow/c/eager/custom_device_testutil.cc +++ b/tensorflow/c/eager/custom_device_testutil.cc @@ -33,6 +33,9 @@ struct LoggingDevice { bool* arrived_flag; // Set to true whenever an operation is executed bool* executed_flag; + // If true, only explicit op placements are accepted. If false, uses + // type-based dispatch. 
+ bool strict_scope_placement; }; struct LoggedTensor { @@ -84,18 +87,35 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context, return nullptr; } -void LoggingDeviceExecute(TFE_Context* context, int num_inputs, - TFE_TensorHandle** inputs, const char* operation_name, - const TFE_OpAttrs* attributes, int* num_outputs, +void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s, void* device_info) { + const char* requested_placement = TFE_OpGetDevice(original_op, s); + if (TF_GetCode(s) != TF_OK) return; + LoggingDevice* dev = reinterpret_cast(device_info); + if (dev->strict_scope_placement && *requested_placement == '\0') { + TF_SetStatus(s, TF_INTERNAL, + "Ops must be placed on the device explicitly, or their inputs " + "first copied to other devices."); + return; + } + TFE_Context* context = TFE_OpGetContext(original_op, s); + if (TF_GetCode(s) != TF_OK) return; + const char* operation_name = TFE_OpGetName(original_op, s); + if (TF_GetCode(s) != TF_OK) return; + const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op); + TFE_Op* op(TFE_NewOp(context, operation_name, s)); if (TF_GetCode(s) != TF_OK) return; TFE_OpAddAttrs(op, attributes); TFE_OpSetDevice(op, dev->underlying_device.c_str(), s); + if (TF_GetCode(s) != TF_OK) return; + int num_inputs = TFE_OpGetFlatInputCount(original_op, s); + if (TF_GetCode(s) != TF_OK) return; for (int j = 0; j < num_inputs; ++j) { - TFE_TensorHandle* input = inputs[j]; + TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s); + if (TF_GetCode(s) != TF_OK) return; const char* input_device = TFE_TensorHandleDeviceName(input, s); if (TF_GetCode(s) != TF_OK) return; if (dev->device_name == input_device) { @@ -131,8 +151,8 @@ void DeleteLoggingDevice(void* device_info) { } // namespace void RegisterLoggingDevice(TFE_Context* context, const char* name, - bool* arrived_flag, bool* executed_flag, - TF_Status* status) { + bool strict_scope_placement, bool* arrived_flag, + bool* executed_flag, TF_Status* status) { TFE_CustomDevice custom_device; custom_device.copy_tensor_to_device = &CopyToLoggingDevice; custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice; @@ -143,6 +163,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name, device->executed_flag = executed_flag; device->device_name = name; device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + device->strict_scope_placement = strict_scope_placement; TFE_RegisterCustomDevice(context, custom_device, name, device, status); } @@ -168,5 +189,6 @@ void AllocateLoggingDevice(const char* name, bool* arrived_flag, logging_device->device_name = name; logging_device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + logging_device->strict_scope_placement = true; *device_info = reinterpret_cast(logging_device); } diff --git a/tensorflow/c/eager/custom_device_testutil.h b/tensorflow/c/eager/custom_device_testutil.h index 509df7d3e3e..a7c60080adf 100644 --- a/tensorflow/c/eager/custom_device_testutil.h +++ b/tensorflow/c/eager/custom_device_testutil.h @@ -25,8 +25,8 @@ limitations under the License. 
#include "tensorflow/c/tf_status.h" void RegisterLoggingDevice(TFE_Context* context, const char* name, - bool* arrived_flag, bool* executed_flag, - TF_Status* status); + bool strict_scope_placement, bool* arrived_flag, + bool* executed_flag, TF_Status* status); void AllocateLoggingDevice(const char* name, bool* arrived_flag, bool* executed_flag, TFE_CustomDevice** device, void** device_info); diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 45048bd6efb..30d2009dc6a 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -109,7 +109,8 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) { // Gets DLPack's DLContext from eager tensor handle. DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) { DLContext ctx; - const char* device_name = tensorflow::unwrap(h)->DeviceName(&status->status); + const char* device_name = + tensorflow::unwrap(h)->BackingDeviceName(&status->status); DeviceNameUtils::ParsedName parsed_name; tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name); std::string device_type = parsed_name.type; diff --git a/tensorflow/c/eager/gradient_checker.cc b/tensorflow/c/eager/gradient_checker.cc new file mode 100644 index 00000000000..640edc7228a --- /dev/null +++ b/tensorflow/c/eager/gradient_checker.cc @@ -0,0 +1,201 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/eager/gradient_checker.h" + +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/gradients/math_grad.h" +#include "tensorflow/c/experimental/gradients/nn_grad.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace gradients { + +using namespace std; + +// ================== Helper functions ================= + +// Fills data with values [start,end) with given step size. +void Range(vector* data, int start, int end, int step = 1) { + for (int i = start; i < end; i += step) { + (*data)[i] = i; + } +} + +// Returns AbstractTensorHandlePtr containing [0, ..., n-1]. +AbstractTensorHandlePtr GetRangeTensorHandleUtil(AbstractContext* ctx, int n) { + vector vals(n); + int64_t vals_shape[] = {n}; + Range(&vals, 0, n); + AbstractTensorHandlePtr r = + GetTensorHandleUtilInt(ctx, vals.data(), vals_shape, 1); + return r; +} + +// Fills out_dims with the dimensions of the given tensor. 
+void GetDims(const TF_Tensor* t, int64_t* out_dims) { + int num_dims = TF_NumDims(t); + for (int i = 0; i < num_dims; i++) { + out_dims[i] = TF_Dim(t, i); + } +} + +// Runs model as is if output is a scalar, +// else sums the output tensor before returning. +Status RunAndMaybeSum(AbstractContext* ctx, Model forward, + absl::Span inputs, + absl::Span outputs, + bool use_function) { + GradientRegistry registry; + std::vector model_outputs(1); + + // Run the model. + TF_RETURN_IF_ERROR(RunModel(forward, ctx, inputs, + absl::MakeSpan(model_outputs), use_function, + registry)); + AbstractTensorHandle* model_out = model_outputs[0]; + + TF_Tensor* model_out_tensor; + TF_RETURN_IF_ERROR(GetValue(model_out, &model_out_tensor)); + int num_dims_out = TF_NumDims(model_out_tensor); + + // If the output is a scalar, then return the scalar output + if (num_dims_out == 0) { + outputs[0] = model_out; + return Status::OK(); + } + + // Else, reduce sum the output to get a scalar + + // Will sum all dimensions, so get a Tensor containing [0,...,num_dims_out-1]. + AbstractTensorHandlePtr sum_dims = + GetRangeTensorHandleUtil(ctx, num_dims_out); + + // Reduce sum the output on all dimensions. + std::vector sum_inputs(2); + sum_inputs[0] = model_out; + sum_inputs[1] = sum_dims.get(); + + TF_RETURN_IF_ERROR(ops::Sum(ctx, absl::MakeSpan(sum_inputs), + absl::MakeSpan(model_outputs), "sum_output")); + outputs[0] = model_outputs[0]; + return Status::OK(); +} +// ========================= End Helper Functions============================== + +Status CalcNumericalGrad(AbstractContext* ctx, Model forward, + absl::Span inputs, + int input_index, bool use_function, + AbstractTensorHandle** numerical_grad) { + AbstractTensorHandle* theta = + inputs[input_index]; // parameter we are grad checking + + // Convert from AbstractTensor to TF_Tensor. + TF_Tensor* theta_tensor; + TF_RETURN_IF_ERROR(GetValue(theta, &theta_tensor)); + + // Get number of elements and fill data. + int num_elems = TF_TensorElementCount(theta_tensor); + vector theta_data(num_elems); + memcpy(theta_data.data(), TF_TensorData(theta_tensor), + TF_TensorByteSize(theta_tensor)); + + // Initialize space for the numerical gradient. + vector dtheta_approx(num_elems); + + // Get theta shape and store in theta_dims. + int num_dims = TF_NumDims(theta_tensor); + vector theta_dims(num_dims); + GetDims(theta_tensor, theta_dims.data()); + + // Initialize auxilary data structures. + vector thetaPlus_data(num_elems); + vector thetaMinus_data(num_elems); + std::vector f_outputs(1); + + // Numerical Grad Check + for (int i = 0; i < num_elems; i++) { + // Get relative epsilon value + float epsilon = + std::abs(theta_data[i] * 1e-4 + 1e-4); // add 1e-4 to prevent div by 0 + AbstractTensorHandlePtr two_eps = + GetScalarTensorHandleUtil(ctx, 2 * epsilon); + + // Initialize theta[i] + epsilon. + memcpy(thetaPlus_data.data(), TF_TensorData(theta_tensor), + TF_TensorByteSize(theta_tensor)); + thetaPlus_data[i] += epsilon; + AbstractTensorHandlePtr thetaPlus = GetTensorHandleUtilFloat( + ctx, thetaPlus_data.data(), theta_dims.data(), num_dims); + + // Initialize theta[i] - epsilon. 
+ memcpy(&thetaMinus_data[0], TF_TensorData(theta_tensor), + TF_TensorByteSize(theta_tensor)); + thetaMinus_data[i] -= epsilon; + AbstractTensorHandlePtr thetaMinus = GetTensorHandleUtilFloat( + ctx, thetaMinus_data.data(), theta_dims.data(), num_dims); + + // Get f(theta + eps): + inputs[input_index] = thetaPlus.get(); + TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs, + absl::MakeSpan(f_outputs), use_function)); + AbstractTensorHandle* fPlus = f_outputs[0]; + + // Get f(theta - eps): + inputs[input_index] = thetaMinus.get(); + TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs, + absl::MakeSpan(f_outputs), use_function)); + AbstractTensorHandle* fMinus = f_outputs[0]; + + // Take Difference of both estimates: (f(theta + eps) - f(theta - eps)). + TF_RETURN_IF_ERROR( + ops::Sub(ctx, {fPlus, fMinus}, absl::MakeSpan(f_outputs), "sub_top")); + AbstractTensorHandle* fDiff = f_outputs[0]; + + // Calculate using the difference quotient definition: + // (f(theta + eps) - f(theta - eps)) / (2 * eps). + TF_RETURN_IF_ERROR(ops::DivNoNan(ctx, {fDiff, two_eps.get()}, + absl::MakeSpan(f_outputs), + "diff_quotient")); + AbstractTensorHandle* diff_quotient = f_outputs[0]; + + TF_Tensor* grad_tensor; + TF_RETURN_IF_ERROR(GetValue(diff_quotient, &grad_tensor)); + float grad_data[1]; + memcpy(&grad_data[0], TF_TensorData(grad_tensor), + TF_TensorByteSize(grad_tensor)); + + dtheta_approx[i] = grad_data[0]; + } + + // Populate *numerical_grad with the data from dtheta_approx. + TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat( + ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad)); + return Status::OK(); +} + +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/eager/gradient_checker.h b/tensorflow/c/eager/gradient_checker.h new file mode 100644 index 00000000000..8497f5af48e --- /dev/null +++ b/tensorflow/c/eager/gradient_checker.h @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/eager/gradients_util.h" +#include "tensorflow/c/experimental/gradients/math_grad.h" +#include "tensorflow/c/experimental/gradients/nn_grad.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace gradients { + +/* Returns numerical grad inside `dtheta_approx` given `forward` model and + * parameter specified by `input_index`. + * + * I.e. 
if y = and w = inputs[input_index], + * this will calculate dy/dw numerically. + * + * `use_function` indicates whether to use graph mode(true) or eager(false). + * + * `numerical_grad` is the pointer to the AbstractTensorHandle* which will + * hold the numerical gradient data at the end of the function. + */ +Status CalcNumericalGrad(AbstractContext* ctx, Model forward, + absl::Span inputs, + int input_index, bool use_function, + AbstractTensorHandle** numerical_grad); + +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/eager/gradient_checker_test.cc b/tensorflow/c/eager/gradient_checker_test.cc new file mode 100644 index 00000000000..7a438085fb5 --- /dev/null +++ b/tensorflow/c/eager/gradient_checker_test.cc @@ -0,0 +1,265 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/eager/gradient_checker.h" + +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/eager/gradients_util.h" +#include "tensorflow/c/eager/mnist_gradients_testutil.h" +#include "tensorflow/c/experimental/gradients/math_grad.h" +#include "tensorflow/c/experimental/gradients/nn_grad.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace gradients { +namespace internal { +namespace { + +class GradientCheckerTest + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); + } +}; + +Status RegisterGradients(GradientRegistry* registry) { + TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer)); + TF_RETURN_IF_ERROR( + registry->Register("SparseSoftmaxCrossEntropyWithLogits", + SparseSoftmaxCrossEntropyWithLogitsRegisterer)); + return Status::OK(); +} + +TEST_P(GradientCheckerTest, TestGradCheckMatMul) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t A_dims[] = {2, 2}; + float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f}; + int64_t 
B_dims[] = {2, 2}; + int num_dims = 2; + + AbstractTensorHandlePtr A = + GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims); + AbstractTensorHandlePtr B = + GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims); + + std::vector inputs; + inputs.push_back(A.get()); + inputs.push_back(B.get()); + + AbstractTensorHandle* grad_approx; + Status s = CalcNumericalGrad( + ctx.get(), MatMulModel, absl::MakeSpan(inputs), /*input_index=*/0, + /*use_function=*/!std::get<2>(GetParam()), &grad_approx); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* gt; + s = GetValue(grad_approx, >); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + float result_data[4] = {0}; + memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt)); + + float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f}; + float tolerance = 1e-2; + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(expected_dA[j], result_data[j], tolerance); + } + TF_DeleteTensor(gt); +} + +TEST_P(GradientCheckerTest, TestGradCheckMul) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = ScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + AbstractTensorHandlePtr y; + { + AbstractTensorHandle* y_raw = nullptr; + Status s = ScalarTensorHandle(ctx.get(), 7.0f, &y_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + y.reset(y_raw); + } + + // Will perform z = x*y. + // dz/dx = y + + std::vector inputs; + inputs.push_back(x.get()); + inputs.push_back(y.get()); + AbstractTensorHandle* g; + + Status s = CalcNumericalGrad(ctx.get(), MulModel, absl::MakeSpan(inputs), + /*input_index=*/0, + /*use_function=*/!std::get<2>(GetParam()), &g); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* gt; + s = GetValue(g, >); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + float result_data[1] = {0}; + memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt)); + + ASSERT_NEAR(result_data[0], 7.0f, /*abs_error=*/1e-2); + TF_DeleteTensor(gt); +} + +TEST_P(GradientCheckerTest, TestGradCheckSoftmax) { + bool use_function = !std::get<2>(GetParam()); + if (use_function) { + // TODO(b/168850692): Enable this. + GTEST_SKIP() << "Can't take gradient of " + "SparseSoftmaxCrossEntropyWithLogits in tracing mode."; + } + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + /** Test to show how to use this API with analytical gradients: + * + * We have `SoftmaxLossGradModel`, which is a wrapper for the + * Softmax analytical gradient found in c/experimental/nn_grads. + * + * We will use the GradientChecker by applying finite differences + * to the forward pass wrapped in `SoftmaxModel` and verify that + * both the analytical and numerical gradients are relatively + * close. 
+ * + */ + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = scores + float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, 1.0f}; + int64_t X_dims[] = {3, 3}; + int num_dims = 2; + AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); + + // y = labels + int y_vals[] = {1, 0, 1}; + int64_t y_dims[] = {3}; + num_dims = sizeof(y_dims) / sizeof(y_dims[0]); + AbstractTensorHandlePtr y = + GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + std::vector inputs; + inputs.push_back(X.get()); + inputs.push_back(y.get()); + + // Run analytical gradient and get its data. + std::vector outputs(2); + s = RunModel(SoftmaxLossGradModel, ctx.get(), absl::MakeSpan(inputs), + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* dX_tensor; + s = GetValue(outputs[0], &dX_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float danalytical[9] = {0}; // Contains data from analytical gradient. + memcpy(&danalytical[0], TF_TensorData(dX_tensor), + TF_TensorByteSize(dX_tensor)); + + // Run numerical gradient approximation using the GradientChecker API. + AbstractTensorHandle* g; // Will contain numerical approximation data. + s = CalcNumericalGrad(ctx.get(), SoftmaxModel, absl::MakeSpan(inputs), + /*input_index=*/0, + /*use_function=*/!std::get<2>(GetParam()), &g); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* gt; + s = GetValue(g, >); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + float dnumerical[9] = {0}; + memcpy(&dnumerical[0], TF_TensorData(gt), TF_TensorByteSize(gt)); + + // Now compare the two implementations: + for (int j = 0; j < 9; j++) { + ASSERT_NEAR(dnumerical[j], danalytical[j], /*abs_error=*/1e-2); + } + + // Only Unref() first output as 2nd is nullptr grad for labels + outputs[0]->Unref(); + TF_DeleteTensor(dX_tensor); + TF_DeleteTensor(gt); +} + +#ifdef PLATFORM_GOOGLE +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, GradientCheckerTest, + ::testing::Combine(::testing::Values("graphdef"), + /*tfrt*/ ::testing::Values(false), + /*executing_eagerly*/ ::testing::Values(true, false))); +#else +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, GradientCheckerTest, + ::testing::Combine(::testing::Values("graphdef"), + /*tfrt*/ ::testing::Values(false), + /*executing_eagerly*/ ::testing::Values(true, false))); +#endif +} // namespace +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc index 406da1291ae..89ff140fa73 100644 --- a/tensorflow/c/eager/gradients.cc +++ b/tensorflow/c/eager/gradients.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/c/eager/gradients.h" #include "absl/strings/str_cat.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/gradients_internal.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" @@ -23,25 +24,97 @@ limitations under the License. 
namespace tensorflow { namespace gradients { -Status GradientRegistry::Register(const string& op_name, - GradientFunctionFactory factory) { +namespace { +Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t, + AbstractTensorHandle** result) { + AbstractOperationPtr op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(op->Reset("ZerosLike", /*raw_device_name=*/nullptr)); + if (isa(op.get())) { + TF_RETURN_IF_ERROR(dyn_cast(op.get())->SetOpName( + absl::StrCat("ZerosLike", ToId(t)).c_str())); + } + TF_RETURN_IF_ERROR(op->AddInput(t)); + int num_outputs = 1; + std::vector outputs(num_outputs); + TF_RETURN_IF_ERROR( + op->Execute(absl::Span(outputs), &num_outputs)); + *result = outputs[0]; + return Status::OK(); +} +} // namespace + +class IncomingGradientsImpl : public IncomingGradients { + public: + explicit IncomingGradientsImpl( + absl::Span grad_inputs, Context* ctx, + DefaultGradientFunction* default_gradients) + : grad_inputs_(grad_inputs), + ctx_(ctx), + default_gradients_(default_gradients) {} + AbstractTensorHandle* operator[](int i) const override { + return default_gradients_->get(ctx_, grad_inputs_, i); + } + size_t size() const override { return grad_inputs_.size(); } + + private: + absl::Span grad_inputs_; + Context* ctx_; + DefaultGradientFunction* default_gradients_; +}; + +AllZerosDefaultGradients::AllZerosDefaultGradients(const ForwardOperation& op) + : outputs_(op.outputs) { + for (auto output : outputs_) { + output->Ref(); + } +} +AbstractTensorHandle* AllZerosDefaultGradients::get( + Context* ctx, absl::Span grad_inputs, int i) { + if (grad_inputs[i]) { + return grad_inputs[i]; + } + if (cached_default_grads_[i]) { + return cached_default_grads_[i].get(); + } + AbstractTensorHandle* result = nullptr; + Status s = ZerosLike(ctx->ctx, outputs_[i], &result); + if (!s.ok()) { + if (result) { + result->Unref(); + } + VLOG(1) << "Failed to create ZerosLike for index " << i; + return nullptr; + } + cached_default_grads_[i].reset(result); + return result; +} + +PassThroughDefaultGradients::PassThroughDefaultGradients( + const ForwardOperation& op) {} +AbstractTensorHandle* PassThroughDefaultGradients::get( + Context* ctx, absl::Span grad_inputs, int i) { + return grad_inputs[i]; +} + +Status GradientRegistry::Register( + const string& op_name, BackwardFunctionFactory backward_function_factory) { auto iter = registry_.find(op_name); if (iter != registry_.end()) { const string error_msg = "Gradient already exists for op: " + op_name + "."; return errors::AlreadyExists(error_msg); } - registry_.insert({op_name, factory}); + registry_.insert({op_name, backward_function_factory}); return Status::OK(); } Status GradientRegistry::Lookup( const ForwardOperation& op, - std::unique_ptr* grad_fn) const { + std::unique_ptr* backward_function) const { auto iter = registry_.find(op.op_name); if (iter == registry_.end()) { const string error_msg = "No gradient defined for op: " + op.op_name + "."; return errors::NotFound(error_msg); } - grad_fn->reset(iter->second(op)); + backward_function->reset(iter->second(op)); return Status::OK(); } @@ -92,33 +165,8 @@ AbstractTensorHandle* TapeTensor::OnesLike() const { } return outputs[0]; } -AbstractTensorHandle* TapeTensor::ZerosLike() const { - AbstractOperationPtr op(ctx_->CreateOperation()); - // TODO(srbs): Consider adding a TF_RETURN_NULLPTR_IF_ERROR. 
- Status s = op->Reset("ZerosLike", /*raw_device_name=*/nullptr); - if (!s.ok()) { - return nullptr; - } - if (isa(op.get())) { - s = dyn_cast(op.get())->SetOpName( - absl::StrCat("ZerosLike", ToId(handle_)).c_str()); - if (!s.ok()) { - return nullptr; - } - } - s = op->AddInput(handle_); - if (!s.ok()) { - return nullptr; - } - int num_outputs = 1; - // TODO(srbs): Figure out who is in charge of releasing this. - std::vector outputs(num_outputs); - s = op->Execute(absl::Span(outputs), &num_outputs); - if (!s.ok()) { - return nullptr; - } - return outputs[0]; -} + +AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; } // Returns the number of elements in the gradient tensor. int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const { @@ -159,13 +207,16 @@ AbstractTensorHandle* TapeVSpace::AggregateGradients( // Calls the passed-in backward function. Status TapeVSpace::CallBackwardFunction( - GradientFunction* backward_function, + BackwardFunction* backward_function, const std::vector& unneeded_gradients, gtl::ArraySlice output_gradients, std::vector* result) const { if (backward_function == nullptr) return Status::OK(); Context ctx = {ctx_}; - return backward_function->Compute(&ctx, output_gradients, result); + IncomingGradientsImpl incoming_gradients( + output_gradients, &ctx, backward_function->GetDefaultGradientFunction()); + return backward_function->GetGradientFunction()->Compute( + &ctx, incoming_gradients, result); } // Looks up the ID of a Gradient. @@ -191,6 +242,7 @@ namespace internal { Status Reset(AbstractOperation* op_, const char* op, const char* raw_device_name, ForwardOperation* forward_op_) { forward_op_->op_name = op; + forward_op_->attrs.Reset(op); return op_->Reset(op, raw_device_name); } Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input, @@ -363,21 +415,30 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx, input_ids[i] = ToId(forward_op_->inputs[i]); input_dtypes[i] = forward_op_->inputs[i]->DataType(); } + for (int i = 0; i < *num_retvals; i++) { + // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs. + forward_op_->outputs.push_back(retvals[i]); + } + // TODO(b/166669239): This is needed to support AttrBuilder::Get for string + // attributes. Number type attrs and DataType attrs work fine without this. + // Consider getting rid of this and making the behavior between number types + // and string consistent. 
+ forward_op_->attrs.BuildNodeDef(); std::vector tape_tensors; for (auto t : retvals) { tape_tensors.push_back(TapeTensor(t, ctx)); } tape->RecordOperation( op_->Name(), tape_tensors, input_ids, input_dtypes, - [registry, forward_op_]() -> GradientFunction* { - std::unique_ptr grad_fn; - Status s = registry.Lookup(*forward_op_, &grad_fn); + [registry, forward_op_]() -> BackwardFunction* { + std::unique_ptr backward_fn; + Status s = registry.Lookup(*forward_op_, &backward_fn); if (!s.ok()) { return nullptr; } - return grad_fn.release(); + return backward_fn.release(); }, - [](GradientFunction* ptr) { + [](BackwardFunction* ptr) { if (ptr) { delete ptr; } diff --git a/tensorflow/c/eager/gradients.h b/tensorflow/c/eager/gradients.h index 267ee5b7ab2..04e11291404 100644 --- a/tensorflow/c/eager/gradients.h +++ b/tensorflow/c/eager/gradients.h @@ -55,18 +55,25 @@ struct Context { public: AbstractContext* ctx; }; + +class IncomingGradients { + public: + virtual AbstractTensorHandle* operator[](int i) const = 0; + virtual size_t size() const = 0; + virtual ~IncomingGradients() {} +}; + class GradientFunction { public: // TODO(srbs): How we support CompositeTensors e.g. IndexedSlices in // `grad_inputs`. - virtual Status Compute(Context* ctx, - absl::Span grad_inputs, + virtual Status Compute(Context* ctx, const IncomingGradients& grad_inputs, std::vector* grad_outputs) = 0; virtual ~GradientFunction() {} }; // Metadata from the forward operation that is made available to the -// gradient registerer to instantiate a GradientFunction. +// gradient registerer to instantiate a BackwardFunction. struct ForwardOperation { public: string op_name; @@ -76,18 +83,86 @@ struct ForwardOperation { AbstractContext* ctx; }; -using GradientFunctionFactory = - std::function; - -// Map from op name to a `GradientFunctionFactory`. -class GradientRegistry { +// Interface for building default zeros gradients for op outputs which are +// missing incoming gradients. Custom implementations of this can be used to +// control which of the forward op's output tensors/their metadata needs to +// be kept around in memory to build the default zeros grad. +// +// Some common helper implementations are provided below. +class DefaultGradientFunction { public: - Status Register(const string& op, GradientFunctionFactory factory); - Status Lookup(const ForwardOperation& op, - std::unique_ptr* grad_fn) const; + virtual AbstractTensorHandle* get( + Context* ctx, absl::Span grad_inputs, + int i) = 0; + virtual ~DefaultGradientFunction() {} +}; + +// Returns zeros for any `nullptr` in `grad_inputs`. +// +// This may require keeping track of all of forward op's output +// tensors and hence may incur a higher memory footprint. Use sparingly. +// +// Multiple calls to `AllZerosDefaultGradients::get` return the same tensor +// handle. +// +// The destructor of this class `Unref`'s any cached tensor handles so users of +// those tensor handles should `Ref` them in order to keep them alive if needed. +class AllZerosDefaultGradients : public DefaultGradientFunction { + public: + explicit AllZerosDefaultGradients(const ForwardOperation& op); + AbstractTensorHandle* get(Context* ctx, + absl::Span grad_inputs, + int i) override; private: - absl::flat_hash_map registry_; + // TODO(srbs): We do not always need to keep the tensors around. In immediate + // execution mode we just need to store the shape and dtype. During tracing + // we may need to keep the tensor around if the shape is not full defined. 
+ std::vector outputs_; + std::vector cached_default_grads_; +}; + +// Passes through `grad_inputs` as-is. The `GradientFunction` +// will be expected to deal with nullptr in `grad_inputs` if any. +class PassThroughDefaultGradients : public DefaultGradientFunction { + public: + explicit PassThroughDefaultGradients(const ForwardOperation& op); + AbstractTensorHandle* get(Context* ctx, + absl::Span grad_inputs, + int i) override; +}; + +// A `BackwardFunction` wraps a `GradientFunction` and a +// `DefaultGradientFunction`. Both are owned by this class' instance. +class BackwardFunction { + public: + BackwardFunction(GradientFunction* gradient_function, + DefaultGradientFunction* default_gradients) + : gradient_function_(gradient_function), + default_gradients_(default_gradients) {} + GradientFunction* GetGradientFunction() { return gradient_function_.get(); } + DefaultGradientFunction* GetDefaultGradientFunction() { + return default_gradients_.get(); + } + + private: + std::unique_ptr gradient_function_; + std::unique_ptr default_gradients_; +}; + +using BackwardFunctionFactory = + std::function; + +// Map from op name to a `BackwardFunctionFactory`. +class GradientRegistry { + public: + Status Register(const string& op, + BackwardFunctionFactory backward_function_factory); + Status Lookup(const ForwardOperation& op, + std::unique_ptr* backward_function) const; + + private: + absl::flat_hash_map registry_; }; // Returns a unique id for the tensor which is used by the tape to build @@ -106,9 +181,16 @@ int64 ToId(AbstractTensorHandle* t); // allow us to trace the data dependencies between operations and hence compute // gradients. // -// This also implements `ZerosLike` and `OnesLike` to create the default +// This also implements `OnesLike` to create the default // incoming gradients for tensors which do not already have an incoming // gradient. +// +// `ZerosLike` is not expected to be called and returns a nullptr. The creation +// of default zeros grads is handled by the `DefaultGradientFunction` registered +// for each op. +// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy. +// Figure out a way to avoid this. +// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr? class TapeTensor { public: TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx); @@ -123,7 +205,7 @@ class TapeTensor { private: AbstractTensorHandle* handle_; - // The context where OnesLike and ZerosLike ops are to be created. + // The context where OnesLike ops are to be created. AbstractContext* ctx_; }; @@ -132,7 +214,7 @@ class TapeTensor { // gradient and for performing gradient aggregation. // See `tensorflow::eager::VSpace` for more details. class TapeVSpace - : public eager::VSpace { + : public eager::VSpace { public: explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {} ~TapeVSpace() override {} @@ -147,7 +229,7 @@ class TapeVSpace // Calls the passed-in backward function. Status CallBackwardFunction( - GradientFunction* backward_function, + BackwardFunction* backward_function, const std::vector& unneeded_gradients, gtl::ArraySlice output_gradients, std::vector* result) const override; @@ -168,8 +250,14 @@ class TapeVSpace }; // A tracing/immediate-execution agnostic tape. +// +// Gradient functions defined for this library support handling null incoming +// gradients. `Tape::ComputeGradient` should be called with +// `build_default_zeros_grads=false`. 
Calling with +// `build_default_zeros_grads=true` (the default) is equivalent but just results +// in extra work because `TapeTensor::ZerosLike` returns a `nullptr` anyway. using Tape = tensorflow::eager::GradientTape; + BackwardFunction, TapeTensor>; } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index e02f189c3d2..3aedf55e97a 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/c_api_experimental.h" @@ -23,6 +24,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/gradients/array_grad.h" #include "tensorflow/c/experimental/gradients/math_grad.h" #include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/tf_status_helper.h" @@ -35,17 +37,26 @@ namespace tensorflow { namespace gradients { namespace internal { namespace { +using std::vector; +using tensorflow::TF_StatusPtr; +using tracing::TracingOperation; class CppGradients : public ::testing::TestWithParam> { protected: void SetUp() override { - TF_SetTracingImplementation(std::get<0>(GetParam())); + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); } }; Status RegisterGradients(GradientRegistry* registry) { - return registry->Register("Add", AddRegisterer); + TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer)); + return Status::OK(); } // Computes `inputs[0] + inputs[1]` and records it on the tape. @@ -58,9 +69,9 @@ Status Add(AbstractContext* ctx, Tape* tape, forward_op.ctx = ctx; TF_RETURN_IF_ERROR( Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op)); - if (isa(add_op.get())) { + if (isa(add_op.get())) { TF_RETURN_IF_ERROR( - dyn_cast(add_op.get())->SetOpName("my_add")); + dyn_cast(add_op.get())->SetOpName("my_add")); } TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op)); TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op)); @@ -69,6 +80,46 @@ Status Add(AbstractContext* ctx, Tape* tape, registry); } +// Computes `exp(inputs[0])` and records it on the tape. +Status Exp(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractOperationPtr exp_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR( + Reset(exp_op.get(), "Exp", /*raw_device_name=*/nullptr, &forward_op)); + if (isa(exp_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(exp_op.get())->SetOpName("my_exp")); + } + TF_RETURN_IF_ERROR(AddInput(exp_op.get(), inputs[0], &forward_op)); + int num_retvals = 1; + return Execute(exp_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, + registry); +} + +// Computes `IdentityN(inputs)` and records it on the tape. 
+Status IdentityN(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractOperationPtr identity_n_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN", + /*raw_device_name=*/nullptr, &forward_op)); + if (isa(identity_n_op.get())) { + TF_RETURN_IF_ERROR(dyn_cast(identity_n_op.get()) + ->SetOpName("my_identity_n")); + } + TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op)); + int num_retvals = outputs.size(); + return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op, + tape, registry); +} + // Computes // y = inputs[0] + inputs[1] // return grad(y, {inputs[0], inputs[1]}) @@ -91,7 +142,8 @@ Status AddGradModel(AbstractContext* ctx, vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])}, /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads)); + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); for (auto add_output : add_outputs) { add_output->Unref(); } @@ -101,6 +153,71 @@ Status AddGradModel(AbstractContext* ctx, return Status::OK(); } +// Computes +// y = exp(inputs[0]) +// return grad(y, {inputs[0]}) +Status ExpGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. + std::vector exp_outputs(1); + TF_RETURN_IF_ERROR(Exp(ctx, tape, inputs, absl::MakeSpan(exp_outputs), + registry)); // Compute x+y. + std::unordered_map + source_tensors_that_are_targets; + + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(exp_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto exp_output : exp_outputs) { + exp_output->Unref(); + } + outputs[0] = out_grads[0]; + delete tape; + return Status::OK(); +} + +// Computes +// ignored, y = IdentityN(inputs[0], inputs[1]) +// return grad(y, {inputs[0], inputs[1]}) +// This should return [nullptr, 1]. 
+Status IdentityNGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); + tape->Watch(ToId(inputs[1])); + + vector identity_n_outputs(2); + TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs, + absl::MakeSpan(identity_n_outputs), registry)); + + std::unordered_map + source_tensors_that_are_targets; + vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(identity_n_outputs[1])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto identity_n_output : identity_n_outputs) { + identity_n_output->Unref(); + } + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; + delete tape; + return Status::OK(); +} + AbstractContext* BuildFunction(const char* fn_name) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -132,26 +249,42 @@ Status RunModel(Model model, AbstractContext* ctx, if (use_function) { const char* fn_name = "test_fn"; std::unique_ptr scoped_func; + // Returning null tensors from a tf.function is not supported, so we keep + // track of indices in the model's outputs are nullptr in this set. + // The FunctionDef only outputs the non-null tensors. We later pad the + // function op outputs to have nullptrs at the `null_indices`. + absl::flat_hash_set null_indices; { AbstractContextPtr func_ctx(BuildFunction(fn_name)); std::vector func_inputs; func_inputs.reserve(inputs.size()); TF_RETURN_IF_ERROR( CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs)); - OutputList output_list; - output_list.expected_num_outputs = outputs.size(); - output_list.outputs.resize(outputs.size()); + vector model_outputs; + model_outputs.resize(outputs.size()); TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs), - absl::MakeSpan(output_list.outputs), registry)); + absl::MakeSpan(model_outputs), registry)); for (auto func_input : func_inputs) { func_input->Unref(); } AbstractFunction* func = nullptr; + OutputList output_list; + output_list.expected_num_outputs = 0; + output_list.outputs.reserve(outputs.size()); + for (int i = 0; i < model_outputs.size(); i++) { + if (model_outputs[i]) { + output_list.outputs.emplace_back(model_outputs[i]); + output_list.expected_num_outputs += 1; + } else { + null_indices.insert(i); + } + } TF_RETURN_IF_ERROR(dyn_cast(func_ctx.get()) ->Finalize(&output_list, &func)); scoped_func.reset(func); - output_list.outputs[0]->Unref(); - output_list.outputs[1]->Unref(); + for (auto output : output_list.outputs) { + output->Unref(); + } TF_RETURN_IF_ERROR(ctx->RegisterFunction(func)); } @@ -160,8 +293,19 @@ Status RunModel(Model model, AbstractContext* ctx, for (auto input : inputs) { TF_RETURN_IF_ERROR(fn_op->AddInput(input)); } - int retvals = outputs.size(); - TF_RETURN_IF_ERROR(fn_op->Execute(outputs, &retvals)); + int retvals = outputs.size() - null_indices.size(); + vector fn_outputs(retvals); + TF_RETURN_IF_ERROR(fn_op->Execute( + absl::Span(fn_outputs.data(), fn_outputs.size()), + &retvals)); + int skipped_indices = 0; + for (int i = 0; i < outputs.size(); i++) { + if (!null_indices.contains(i)) { + outputs[i] = fn_outputs[i - skipped_indices]; + } else { + skipped_indices += 1; + } + } TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name)); return Status::OK(); } else { @@ -264,18 +408,172 @@ TEST_P(CppGradients, 
TestAddGrad) { TF_DeleteTensor(result_tensor); } -// TODO(b/160888630): Enable this test with mlir after AddInputList is -// supported. It is needed for AddN op which is used for gradient aggregation. +TEST_P(CppGradients, TestExpGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Pseudo-code: + // + // tape.watch(x) + // y = exp(x) + // outputs = tape.gradient(y, x) + std::vector outputs(1); + s = RunModel(ExpGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* result_tensor; + s = getValue(outputs[0], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_NEAR(*result_value, 2.718, 0.001); + outputs[0]->Unref(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; +} + +TEST_P(CppGradients, TestIdentityNGrad) { + // Pseudo-code: + // + // tape.watch(x1) + // tape.watch(x2) + // unused, y = IdentityN([x1, x2]) + // outputs = tape.gradient(y, [x1, x2]) + // Expected: [nullptr, 1] + // + // This test is interesting because the current implementation of GradientTape + // would return [0, 1] whereas we use build_default_zeros_grads=false here + // so we get back [nullptr, 1]. 
+ std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x1; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x1.reset(x_raw); + } + AbstractTensorHandlePtr x2; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x2.reset(x_raw); + } + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + std::vector outputs(2); + s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + EXPECT_EQ(outputs[0], nullptr); + TF_Tensor* result_tensor; + s = getValue(outputs[1], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, 1.0); + outputs[1]->Unref(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; +} + +TEST_P(CppGradients, TestSetAttrString) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr t; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + t.reset(x_raw); + } + + AbstractOperationPtr check_numerics_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx.get(); + Status s = Reset(check_numerics_op.get(), "CheckNumerics", + /*raw_device_name=*/nullptr, &forward_op); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + if (isa(check_numerics_op.get())) { + s = dyn_cast(check_numerics_op.get()) + ->SetOpName("check_numerics"); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + } + s = AddInput(check_numerics_op.get(), t.get(), &forward_op); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + string message = "This is the way!"; + s = SetAttrString(check_numerics_op.get(), "message", message.data(), + message.length(), &forward_op); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + int num_retvals = 1; + std::vector outputs(1); + GradientRegistry registry; + std::unique_ptr tape(new Tape(/*persistent=*/false)); + s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs), + &num_retvals, &forward_op, tape.get(), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + string read_message; + s = forward_op.attrs.Get("message", &read_message); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(read_message, message); +} + +// TODO(b/164171226): Enable this test with tfrt after AddInputList is +// supported. It is needed for IdentityN. 
#ifdef PLATFORM_GOOGLE INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, CppGradients, - ::testing::Combine(::testing::Values("graphdef"), - /*tfrt*/ ::testing::Values(true, false), + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(false), /*executing_eagerly*/ ::testing::Values(true, false))); #else INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, CppGradients, - ::testing::Combine(::testing::Values("graphdef"), + ::testing::Combine(::testing::Values("graphdef", "mlir"), /*tfrt*/ ::testing::Values(false), /*executing_eagerly*/ ::testing::Values(true, false))); #endif diff --git a/tensorflow/c/eager/gradients_util.cc b/tensorflow/c/eager/gradients_util.cc new file mode 100644 index 00000000000..e53faf4a3f3 --- /dev/null +++ b/tensorflow/c/eager/gradients_util.cc @@ -0,0 +1,317 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/eager/gradients_util.h" + +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace gradients { + +using namespace std; + +Status ScalarTensorHandleHelper(TFE_Context* ctx, float value, + TFE_TensorHandle** result) { + float data[] = {value}; + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Tensor* t = + TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status.get()); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get()); + *result = th; + TF_DeleteTensor(t); + return StatusFromTF_Status(status.get()); +} + +Status TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[], + int64_t dims[], int num_dims, + TFE_TensorHandle** result) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Tensor* t = + TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status.get()); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get()); + *result = th; + TF_DeleteTensor(t); + return StatusFromTF_Status(status.get()); +} + +Status TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[], + int64_t dims[], int num_dims, + TFE_TensorHandle** result) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Tensor* t = + 
TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status.get()); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get()); + *result = th; + TF_DeleteTensor(t); + return StatusFromTF_Status(status.get()); +} + +// Get a scalar TensorHandle with given value +Status ScalarTensorHandle(AbstractContext* ctx, float value, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager; + TF_RETURN_IF_ERROR(ScalarTensorHandleHelper(eager_ctx, value, &input_eager)); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return StatusFromTF_Status(status.get()); +} + +// Get a TensorHandle with given float values and dimensions +Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[], + int64_t dims[], int num_dims, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager; + TF_RETURN_IF_ERROR(TensorHandleWithDimsFloatHelper(eager_ctx, data, dims, + num_dims, &input_eager)); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return StatusFromTF_Status(status.get()); +} + +// Get a TensorHandle with given int values and dimensions +Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[], + int num_dims, AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager; + TF_RETURN_IF_ERROR(TensorHandleWithDimsIntHelper(eager_ctx, data, dims, + num_dims, &input_eager)); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return StatusFromTF_Status(status.get()); +} + +Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_TensorHandle* result_t = + TF_AbstractTensorGetEagerTensor(wrap(t), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + *result_tensor = TFE_TensorHandleResolve(result_t, status.get()); + return StatusFromTF_Status(status.get()); +} + +AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx, + float vals[], int64_t dims[], + int num_dims) { + AbstractTensorHandlePtr A; + AbstractTensorHandle* a_raw = nullptr; + Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw); + if (s.ok()) { + A.reset(a_raw); + } + return A; +} + +AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[], + int64_t dims[], int num_dims) { + AbstractTensorHandlePtr A; + AbstractTensorHandle* a_raw = nullptr; + Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw); + if (s.ok()) { + A.reset(a_raw); + } + return A; +} + +AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx, + float val) { + AbstractTensorHandlePtr y; + AbstractTensorHandle* y_raw = nullptr; + Status s = ScalarTensorHandle(ctx, val, &y_raw); + if (s.ok()) { + y.reset(y_raw); + } + return y; +} + 
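// [Editorial sketch, not part of the patch] A minimal example of how the
// helpers above compose, assuming an AbstractContext* obtained from
// BuildImmediateExecutionContext (defined later in this file); the name
// `RoundTripExample` is hypothetical and used only for illustration.
Status RoundTripExample(AbstractContext* ctx) {
  float vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t dims[] = {2, 2};
  // Wrap host data in an abstract tensor handle.
  AbstractTensorHandlePtr a =
      GetTensorHandleUtilFloat(ctx, vals, dims, /*num_dims=*/2);
  if (a == nullptr) return errors::Internal("Could not create handle.");
  // Resolve the handle back to a host TF_Tensor and copy out the data.
  TF_Tensor* t = nullptr;
  TF_RETURN_IF_ERROR(GetValue(a.get(), &t));
  float host[4] = {0};
  memcpy(host, TF_TensorData(t), TF_TensorByteSize(t));  // host mirrors vals
  TF_DeleteTensor(t);
  return Status::OK();
}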
+Status UpdateWeights(AbstractContext* ctx, vector& grads, + vector& weights, + AbstractTensorHandle* learning_rate) { + /* Update weights one by one using gradient update rule: + * + * w -= lr*grad[w] + * + * NOTE: assuming learning rate is positive + */ + + int num_grads = grads.size(); + vector temp_outputs(1); + std::string update_str; + + // Negate learning rate for gradient descent + TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate}, + absl::MakeSpan(temp_outputs), + "neg_lr")); // Compute -lr + learning_rate = temp_outputs[0]; + + for (int i = 0; i < num_grads; i++) { + // Compute dW = -lr * grad(w[i]) + update_str = "update_mul_" + std::to_string(i); + TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]}, + absl::MakeSpan(temp_outputs), + update_str.c_str())); + + AbstractTensorHandle* dW = temp_outputs[0]; + + // Compute temp = weights[i] + dW + update_str = "update_add_" + std::to_string(i); + TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW}, + absl::MakeSpan(temp_outputs), + update_str.c_str())); + + // Update the weights + weights[i] = temp_outputs[0]; + } + + return Status::OK(); +} + +AbstractContext* BuildFunction(const char* fn_name) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get()); + return unwrap(graph_ctx); +} + +Status CreateParamsForInputs(AbstractContext* ctx, + absl::Span inputs, + vector* params) { + tracing::TracingTensorHandle* handle = nullptr; + for (auto input : inputs) { + TF_RETURN_IF_ERROR(dyn_cast(ctx)->AddParameter( + input->DataType(), &handle)); + params->emplace_back(handle); + } + return Status::OK(); +} + +Status RunModel(Model model, AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, bool use_function, + const GradientRegistry& registry) { + if (use_function) { + const char* fn_name = "test_fn"; + std::unique_ptr scoped_func; + // Returning null tensors from a tf.function is not supported, so we keep + // track of indices in the model's outputs are nullptr in this set. + // The FunctionDef only outputs the non-null tensors. We later pad the + // function op outputs to have nullptrs at the `null_indices`. 
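    // [Editor's note] A concrete trace of the padding described above, for a
    // model whose traced outputs are {nullptr, g1}: only g1 is recorded as a
    // FunctionDef output, null_indices becomes {0}, Execute fills
    // fn_outputs = {g1'}, and the loop after Execute copies g1' into
    // outputs[1] while leaving outputs[0] untouched (callers pre-initialize
    // their output vectors, so untouched entries stay nullptr).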
+ absl::flat_hash_set null_indices; + { + AbstractContextPtr func_ctx(BuildFunction(fn_name)); + vector func_inputs; + func_inputs.reserve(inputs.size()); + TF_RETURN_IF_ERROR( + CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs)); + vector model_outputs; + model_outputs.resize(outputs.size()); + TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs), + absl::MakeSpan(model_outputs), registry)); + for (auto func_input : func_inputs) { + func_input->Unref(); + } + AbstractFunction* func = nullptr; + OutputList output_list; + output_list.expected_num_outputs = 0; + output_list.outputs.reserve(outputs.size()); + for (int i = 0; i < model_outputs.size(); i++) { + if (model_outputs[i]) { + output_list.outputs.emplace_back(model_outputs[i]); + output_list.expected_num_outputs += 1; + } else { + null_indices.insert(i); + } + } + TF_RETURN_IF_ERROR(dyn_cast(func_ctx.get()) + ->Finalize(&output_list, &func)); + scoped_func.reset(func); + for (auto output : output_list.outputs) { + output->Unref(); + } + TF_RETURN_IF_ERROR(ctx->RegisterFunction(func)); + } + + AbstractOperationPtr fn_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr)); + for (auto input : inputs) { + TF_RETURN_IF_ERROR(fn_op->AddInput(input)); + } + int retvals = outputs.size() - null_indices.size(); + vector fn_outputs(retvals); + TF_RETURN_IF_ERROR(fn_op->Execute( + absl::Span(fn_outputs.data(), fn_outputs.size()), + &retvals)); + int skipped_indices = 0; + for (int i = 0; i < outputs.size(); i++) { + if (!null_indices.contains(i)) { + outputs[i] = fn_outputs[i - skipped_indices]; + } else { + skipped_indices += 1; + } + } + TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name)); + return Status::OK(); + } else { + return model(ctx, inputs, outputs, registry); + } +} + +Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get())); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_DeleteContextOptions(opts); + return Status::OK(); +} + +} // namespace gradients +} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/c/eager/gradients_util.h b/tensorflow/c/eager/gradients_util.h new file mode 100644 index 00000000000..cd0bbc0720d --- /dev/null +++ b/tensorflow/c/eager/gradients_util.h @@ -0,0 +1,88 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace gradients { + +// Get a scalar TensorHandle with given value +Status ScalarTensorHandle(AbstractContext* ctx, float value, + AbstractTensorHandle** tensor); + +// Get a TensorHandle with given float values and dimensions +Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[], + int64_t dims[], int num_dims, + AbstractTensorHandle** tensor); + +// Get a TensorHandle with given int values and dimensions +Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[], + int num_dims, AbstractTensorHandle** tensor); + +// Places data from `t` into *result_tensor. +Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor); + +// Util function that wraps an AbstractTensorHandle* with given data and dims. +AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx, + float vals[], int64_t dims[], + int num_dims); + +// Util function that wraps an AbstractTensorHandle* with given data and dims. +AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[], + int64_t dims[], int num_dims); + +// Util function that wraps an AbstractTensorHandle* with given data. +AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx, + float val); + +// Performs gradient update for each weight using given learning rate. +Status UpdateWeights(AbstractContext* ctx, + std::vector& grads, + std::vector& weights, + AbstractTensorHandle* learning_rate); + +using Model = std::function, + absl::Span, const GradientRegistry&)>; + +// Runs given model in either graph or eager mode depending on value of +// use_function. +Status RunModel(Model model, AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, bool use_function, + const GradientRegistry& registry); + +// Builds context and returns inside *ctx. +Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx); + +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index 6d06d9a8de6..02a3320ef65 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -57,15 +57,10 @@ class ImmediateExecutionContext : public AbstractContext { // Create a tensor instance from the given data buffer and description. // `memory_releaser` will be called on destruction, and it's responsible for - // cleaning up the underlying buffer. `convert_string` indicates whether it - // has to handle tstring conversion. Expected to be removed once tstring - // migration is done. 
- virtual AbstractTensorInterface* CreateTensor(DataType dtype, - const int64_t* dims, - int num_dims, void* data, - size_t len, bool convert_string, - MemoryReleaser memory_releaser, - void* memory_releaser_arg) = 0; + // cleaning up the underlying buffer. + virtual AbstractTensorInterface* CreateTensor( + DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len, + MemoryReleaser memory_releaser, void* memory_releaser_arg) = 0; // Create a handle to wrap and manage a Tensor virtual ImmediateExecutionTensorHandle* CreateLocalHandle( diff --git a/tensorflow/c/eager/immediate_execution_operation.h b/tensorflow/c/eager/immediate_execution_operation.h index ee212b21a96..7b68ec2c9f4 100644 --- a/tensorflow/c/eager/immediate_execution_operation.h +++ b/tensorflow/c/eager/immediate_execution_operation.h @@ -47,9 +47,6 @@ class ImmediateExecutionOperation : public AbstractOperation { virtual Status InputLength(const char* input_name, int* length) = 0; virtual Status OutputLength(const char* output_name, int* length) = 0; - // Experimental - virtual Status SetUseXla(bool enable) = 0; - // Set stack trace to be used for potential async error reporting. virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0; diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc new file mode 100644 index 00000000000..4114f50a798 --- /dev/null +++ b/tensorflow/c/eager/mnist_gradients_test.cc @@ -0,0 +1,723 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/eager/gradients_util.h" +#include "tensorflow/c/eager/mnist_gradients_testutil.h" +#include "tensorflow/c/experimental/gradients/math_grad.h" +#include "tensorflow/c/experimental/gradients/nn_grad.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace gradients { +namespace internal { +namespace { +using tensorflow::TF_StatusPtr; + +class CppGradients + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); + } +}; + +Status RegisterGradients(GradientRegistry* registry) { + TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer)); + TF_RETURN_IF_ERROR( + registry->Register("SparseSoftmaxCrossEntropyWithLogits", + SparseSoftmaxCrossEntropyWithLogitsRegisterer)); + return Status::OK(); +} + +TEST_P(CppGradients, TestMatMulGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t A_dims[] = {2, 2}; + float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f}; + int64_t B_dims[] = {2, 2}; + int num_dims = 2; + + AbstractTensorHandlePtr A = + GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims); + AbstractTensorHandlePtr B = + GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims); + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + /* Pseudo-code: + * + * tape.watch(A) + * tape.watch(B) + * Y = AB + * outputs = tape.gradient(Y, [A, B]) + */ + + std::vector outputs(2); + s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* dA_tensor; + s = GetValue(outputs[0], &dA_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[4] = {0}; + memcpy(&result_data[0], TF_TensorData(dA_tensor), + TF_TensorByteSize(dA_tensor)); + + float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f}; + float tolerance = 1e-3; + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(result_data[j], expected_dA[j], tolerance); + } + + TF_Tensor* dB_tensor; + s = GetValue(outputs[1], &dB_tensor); + ASSERT_EQ(errors::OK, 
s.code()) << s.error_message(); + + memcpy(&result_data[0], TF_TensorData(dB_tensor), + TF_TensorByteSize(dB_tensor)); + + float expected_dB[4] = {4.0f, 4.0f, 6.0f, 6.0f}; + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(result_data[j], expected_dB[j], tolerance); + } + + outputs[0]->Unref(); + outputs[1]->Unref(); + TF_DeleteTensor(dA_tensor); + TF_DeleteTensor(dB_tensor); +} + +TEST_P(CppGradients, TestMNISTForward) { + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = data + float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t dims[] = {2, 2}; + int num_dims = 2; + AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, dims, num_dims); + + // W1 = first weights + float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f}; + AbstractTensorHandlePtr W1 = + GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); + + // W2 = second weights + float W2_vals[] = {.1f, .2f, .3f, -.5f}; + AbstractTensorHandlePtr W2 = + GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims); + + // y = labels + int y_vals[] = {1, 1}; + int64_t dims_y[] = {2}; + num_dims = sizeof(dims_y) / sizeof(dims_y[0]); + AbstractTensorHandlePtr y = + GetTensorHandleUtilInt(ctx.get(), y_vals, dims, num_dims); + + GradientRegistry registry; + + // Run the Forward Pass + std::vector outputs(2); + Status s = + RunModel(MNISTForwardModel, ctx.get(), + {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Verify the Results + TF_Tensor* scores_tensor; + s = GetValue(outputs[0], &scores_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[4] = {0}; + memcpy(&result_data[0], TF_TensorData(scores_tensor), + TF_TensorByteSize(scores_tensor)); + + float expected_scores[4] = {3.6f, -6.0f, 10.2f, -17.0f}; + float tolerance = 1e-3; + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(result_data[j], expected_scores[j], tolerance); + } + + TF_Tensor* loss_vals_tensor; + s = GetValue(outputs[1], &loss_vals_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + memcpy(&result_data[0], TF_TensorData(loss_vals_tensor), + TF_TensorByteSize(loss_vals_tensor)); + float expected_losses[2] = {9.6f, 27.2f}; + for (int j = 0; j < 2; j++) { + ASSERT_NEAR(result_data[j], expected_losses[j], tolerance); + } + + outputs[0]->Unref(); + outputs[1]->Unref(); + TF_DeleteTensor(scores_tensor); + TF_DeleteTensor(loss_vals_tensor); +} + +TEST_P(CppGradients, TestMNISTForward2) { + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = data + float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + int64_t X_dims[] = {3, 2}; + int num_dims = 2; + AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); + + // W1 = first weights + float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f}; + int64_t dims[] = {2, 2}; + AbstractTensorHandlePtr W1 = + GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); + + // W2 = second weights + float W2_vals[] = {.1f, .2f, .3f, -.5f}; + AbstractTensorHandlePtr W2 = + GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims); + + // y = labels + 
int y_vals[] = {1, 1, 1}; + int64_t y_dims[] = {3}; + num_dims = sizeof(y_dims) / sizeof(y_dims[0]); + AbstractTensorHandlePtr y = + GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); + + GradientRegistry registry; + + // Run the Forward Pass + std::vector outputs(2); + Status s = + RunModel(MNISTForwardModel, ctx.get(), + {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Verify the Results + TF_Tensor* scores_tensor; + s = GetValue(outputs[0], &scores_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[6] = {0}; + memcpy(&result_data[0], TF_TensorData(scores_tensor), + TF_TensorByteSize(scores_tensor)); + + float expected_scores[6] = {3.6f, -6.0f, 10.2f, -17.0f, 16.8f, -28.0f}; + float tolerance = 1e-3; + for (int j = 0; j < 6; j++) { + ASSERT_NEAR(result_data[j], expected_scores[j], tolerance); + } + + TF_Tensor* loss_vals_tensor; + s = GetValue(outputs[1], &loss_vals_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + memcpy(&result_data[0], TF_TensorData(loss_vals_tensor), + TF_TensorByteSize(loss_vals_tensor)); + float expected_losses[3] = {9.6f, 27.2f, 44.8f}; + for (int j = 0; j < 3; j++) { + ASSERT_NEAR(result_data[j], expected_losses[j], tolerance); + } + + outputs[0]->Unref(); + outputs[1]->Unref(); + TF_DeleteTensor(scores_tensor); + TF_DeleteTensor(loss_vals_tensor); +} + +TEST_P(CppGradients, TestMatMulTranspose) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = data + float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + int64_t X_dims[] = {2, 3}; + int num_dims = 2; + AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); + + // W1 = first weights + float W1_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t dims[] = {2, 2}; + AbstractTensorHandlePtr W1 = + GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); + + GradientRegistry registry; + + // Run the MatMul Op + std::vector outputs(1); + + Status s = RunModel(MatMulTransposeModel, ctx.get(), {X.get(), W1.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Verify the Results + TF_Tensor* scores_tensor; + s = GetValue(outputs[0], &scores_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[6] = {0}; + memcpy(&result_data[0], TF_TensorData(scores_tensor), + TF_TensorByteSize(scores_tensor)); + + float expected_scores[6] = {13.0f, 18.0f, 17.0f, 24.0f, 21.0f, 30.0f}; + float tolerance = 1e-3; + for (int j = 0; j < 6; j++) { + ASSERT_NEAR(result_data[j], expected_scores[j], tolerance); + } +} + +TEST_P(CppGradients, TestReluGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = data + float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f}; + int64_t X_dims[] = {3, 3}; + int num_dims = 2; + AbstractTensorHandlePtr X = + 
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + /* Pseudo-code: + * + * tape.watch(X) + * Y = Relu(X) + * outputs = tape.gradient(Y, [X]) + */ + std::vector outputs(1); + s = RunModel(ReluGradModel, ctx.get(), {X.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* dX_tensor; + s = GetValue(outputs[0], &dX_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[9] = {0}; + memcpy(&result_data[0], TF_TensorData(dX_tensor), + TF_TensorByteSize(dX_tensor)); + + float expected_dX[9] = {1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}; + float tolerance = 1e-3; + for (int j = 0; j < 9; j++) { + ASSERT_NEAR(result_data[j], expected_dX[j], tolerance); + } + + outputs[0]->Unref(); + TF_DeleteTensor(dX_tensor); +} + +TEST_P(CppGradients, TestSoftmaxLossGrad) { + bool use_function = !std::get<2>(GetParam()); + if (use_function) { + // TODO(b/168850692): Enable this. + GTEST_SKIP() << "Can't take gradient of " + "SparseSoftmaxCrossEntropyWithLogits in tracing mode."; + } + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = scores + float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f}; + int64_t X_dims[] = {3, 3}; + int num_dims = 2; + AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); + + // y = labels + int y_vals[] = {1, 0, 1}; + int64_t y_dims[] = {3}; + num_dims = sizeof(y_dims) / sizeof(y_dims[0]); + AbstractTensorHandlePtr y = + GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + /* Pseudo-code: + * + * tape.watch(X) + * tape.watch(labels) + * loss = SoftmaxLoss(X, labels) + * outputs = tape.gradient(loss, [X, labels]) + * + * + */ + + std::vector outputs(2); + s = RunModel(SoftmaxLossGradModel, ctx.get(), {X.get(), y.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* dX_tensor; + s = GetValue(outputs[0], &dX_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[9] = {0}; + memcpy(&result_data[0], TF_TensorData(dX_tensor), + TF_TensorByteSize(dX_tensor)); + + float expected_dX[9] = {0.090f, -0.7553f, 0.6652f, -0.9099f, 0.2447f, + 0.6652f, 0.8437f, -0.8858f, 0.0420f}; + float tolerance = 1e-3; + for (int j = 0; j < 9; j++) { + ASSERT_NEAR(result_data[j], expected_dX[j], tolerance); + } + + // Only Unref() first output as 2nd is nullptr grad for labels + outputs[0]->Unref(); + TF_DeleteTensor(dX_tensor); +} + +TEST_P(CppGradients, TestMNISTGrad) { + bool use_function = !std::get<2>(GetParam()); + if (use_function) { + // TODO(b/168850692): Enable this. 
+ GTEST_SKIP() << "Can't take gradient of " + "SparseSoftmaxCrossEntropyWithLogits in tracing mode."; + } + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = data + float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t X_dims[] = {2, 2}; + int num_dims = 2; + AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); + + // W1 = first weights + float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f}; + int64_t dims[] = {2, 2}; + AbstractTensorHandlePtr W1 = + GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); + + // W2 = second weights + float W2_vals[] = {.1f, .2f, .3f, -.5f}; + AbstractTensorHandlePtr W2 = + GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims); + + // y = labels + int y_vals[] = {1, 1}; + int64_t y_dims[] = {2}; + num_dims = sizeof(y_dims) / sizeof(y_dims[0]); + AbstractTensorHandlePtr y = + GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); + + // Register Grads + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + /* Pseudo-code: + * + * + * tape.watch(W1) + * tape.watch(W2) + * mm = X*W1 + * hidden = Relu(mm) + * scores = W2*hidden + * loss = SoftmaxLoss(scores, y) + * outputs = tape.gradient(loss, [A, B]) + * + */ + + std::vector outputs(3); + s = RunModel(MNISTGradModel, ctx.get(), + {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float tolerance = 1e-3; + TF_Tensor* dW1_tensor; + s = GetValue(outputs[0], &dW1_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[4] = {0}; + memcpy(&result_data[0], TF_TensorData(dW1_tensor), + TF_TensorByteSize(dW1_tensor)); + + float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f}; + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance); + } + + TF_Tensor* dW2_tensor; + s = GetValue(outputs[1], &dW2_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + memcpy(&result_data[0], TF_TensorData(dW2_tensor), + TF_TensorByteSize(dW2_tensor)); + + float expected_dW2[4] = {0.0f, 0.0f, 46.0f, -46.0f}; // dLoss + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(result_data[j], expected_dW2[j], tolerance); + } + + outputs[0]->Unref(); + outputs[1]->Unref(); + outputs[2]->Unref(); + TF_DeleteTensor(dW1_tensor); + TF_DeleteTensor(dW2_tensor); +} + +TEST_P(CppGradients, TestScalarMul) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr eta; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = ScalarTensorHandle(ctx.get(), 1.5f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + eta.reset(x_raw); + } + + float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t A_dims[] = {2, 2}; + int num_dims = 2; + + AbstractTensorHandlePtr A = + GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims); + + GradientRegistry registry; + std::vector outputs(1); + Status s = RunModel(ScalarMulModel, 
ctx.get(), {eta.get(), A.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* dA_tensor; + s = GetValue(outputs[0], &dA_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[4] = {0}; + memcpy(&result_data[0], TF_TensorData(dA_tensor), + TF_TensorByteSize(dA_tensor)); + + float tolerance = 1e-3; + float eta_val = 1.5f; + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(result_data[j], eta_val * A_vals[j], tolerance); + } + + outputs[0]->Unref(); + TF_DeleteTensor(dA_tensor); +} + +TEST_P(CppGradients, TestMNIST_Training) { + bool use_function = !std::get<2>(GetParam()); + if (use_function) { + // TODO(b/168850692): Enable this. + GTEST_SKIP() << "Can't take gradient of " + "SparseSoftmaxCrossEntropyWithLogits in tracing mode."; + } + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = data + float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t X_dims[] = {2, 2}; + int num_dims = 2; + AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); + + // TODO(amturati): use random initializer for weights instead of + // constant values. + + // W1 = first weights + float W1_vals[] = {-.01f, 0.4f, 0.5f, -.2f}; + int64_t dims[] = {2, 2}; + AbstractTensorHandlePtr W1 = + GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); + + // W2 = second weights + float W2_vals[] = {.1f, .2f, .3f, -.5f}; + AbstractTensorHandlePtr W2 = + GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims); + + // y = labels + int y_vals[] = {1, 1}; + int64_t y_dims[] = {2}; + num_dims = sizeof(y_dims) / sizeof(y_dims[0]); + AbstractTensorHandlePtr y = + GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); + + // Register Grads + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Prepare for training + std::vector weights; + weights.push_back(W1.get()); + weights.push_back(W2.get()); + + // Set learning rate to be 1e-1 + AbstractTensorHandle* learning_rate = nullptr; + s = ScalarTensorHandle(ctx.get(), 1e-1, &learning_rate); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Train + int num_iters = 10; + std::vector mnist_outputs(3); + std::vector grads(2); + for (int i = 0; i < num_iters; i++) { + // Run Forward Pass + s = RunModel(MNISTGradModel, ctx.get(), + {X.get(), weights[0], weights[1], y.get()}, + absl::MakeSpan(mnist_outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Fill grads + grads[0] = mnist_outputs[0]; + grads[1] = mnist_outputs[1]; + + // Gradient Update + s = UpdateWeights(ctx.get(), grads, weights, learning_rate); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + } + + grads[0]->Unref(); // release W1_grad + grads[1]->Unref(); // release W2_grad + mnist_outputs[2]->Unref(); // release loss +} + +#ifdef PLATFORM_GOOGLE +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, CppGradients, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(false), + /*executing_eagerly*/ ::testing::Values(true, false))); +#else +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, CppGradients, + 
::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(false), + /*executing_eagerly*/ ::testing::Values(true, false))); +#endif +} // namespace +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/eager/mnist_gradients_testutil.cc b/tensorflow/c/eager/mnist_gradients_testutil.cc new file mode 100644 index 00000000000..932605ab8e0 --- /dev/null +++ b/tensorflow/c/eager/mnist_gradients_testutil.cc @@ -0,0 +1,518 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/eager/mnist_gradients_testutil.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/eager/gradients_util.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" + +// ========================== Tape Ops ============================== + +namespace tensorflow { +namespace gradients { +namespace internal { + +using std::vector; +using tensorflow::tracing::TracingOperation; + +// Computes `inputs[0] + inputs[1]` and records it on the tape. +Status Add(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractOperationPtr add_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR( + Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op)); + if (isa(add_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(add_op.get())->SetOpName("my_add")); + } + TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op)); + TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op)); + int num_retvals = 1; + return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, + registry); +} + +// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape. 
+Status MatMul(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, const char* name, + bool transpose_a, bool transpose_b, + const GradientRegistry& registry) { + AbstractOperationPtr matmul_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul", + /*raw_device_name=*/nullptr, &forward_op)); + if (isa(matmul_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(matmul_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op)); + TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op)); + TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool( + matmul_op.get(), "transpose_a", transpose_a, &forward_op)); + TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool( + matmul_op.get(), "transpose_b", transpose_b, &forward_op)); + + int num_retvals = 1; + return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, + registry); +} + +Status Mul(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, const char* name, + const GradientRegistry& registry) { + AbstractOperationPtr mul_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR( + Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op)); + if (isa(mul_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(mul_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op)); + TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op)); + + int num_retvals = 1; + return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, + registry); +} + +// Computes `Relu(inputs[0])` and records it on the tape. +Status Relu(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, const char* name, + const GradientRegistry& registry) { + AbstractOperationPtr relu_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR( + Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op)); + if (isa(relu_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(relu_op.get())->SetOpName(name)); + } + TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op)); + int num_retvals = 1; + return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, + registry); +} + +// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not +// one-hot) and records it on the tape. 
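// [Editor's note] "Categorical (not one-hot)" means `labels` holds one class
// index per row of `scores`. For example, with three 3-class score rows,
// labels = {1, 0, 1} selects class 1, class 0 and class 1; the one-hot
// equivalent would be {{0,1,0},{1,0,0},{0,1,0}}. This matches the int32 label
// tensors built in the tests (e.g. y_vals[] = {1, 0, 1}).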
+Status SparseSoftmaxCrossEntropyWithLogits( + AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, const char* name, + const GradientRegistry& registry) { + AbstractTensorHandle* scores = inputs[0]; + AbstractTensorHandle* labels = inputs[1]; + + AbstractOperationPtr sm_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits", + /*raw_device_name=*/nullptr, &forward_op)); + if (isa(sm_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(sm_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op)); + TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op)); + + int num_retvals = 2; // returns loss values and backprop + return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, + registry); +} + +//===================== Test Models to run ========================= + +// Computes +// y = inputs[0] + inputs[1] +// return grad(y, {inputs[0], inputs[1]}) +Status AddGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. + tape->Watch(ToId(inputs[1])); // Watch y. + std::vector add_outputs(1); + TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs), + registry)); // Compute x+y. + std::unordered_map + source_tensors_that_are_targets; + + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto add_output : add_outputs) { + add_output->Unref(); + } + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; + delete tape; + return Status::OK(); +} + +// Computes +// y = inputs[0] * inputs[1] +// return grad(y, {inputs[0], inputs[1]}) +Status MatMulGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. + tape->Watch(ToId(inputs[1])); // Watch y. + vector mm_outputs(1); + TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs), + "matmul0", /*transpose_a=*/false, + /*transpose_b=*/false, registry)); // Compute x*y. 
+ + std::unordered_map + source_tensors_that_are_targets; + + vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(mm_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto mm_output : mm_outputs) { + mm_output->Unref(); + } + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; + delete tape; + return Status::OK(); +} + +// Model to run 2-layer net +Status MNISTForwardModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + /** + * We will trace a 2-layer fully connected network for an MNIST model: + * + * def mnist_forward(X, W1, W2, y_labels): + * mm_out_1 = tf.matmul(X,W1) + * hidden_layer = tf.nn.relu(mm_out_1) + * scores = tf.matmul(hidden_layer,W2) + * softmax = + * tf.nn.sparse_softmax_cross_entropy_with_logits(scores, + * y_labels) + * return scores, softmax + * + * Use this convention for inputs: + * + * inputs = [X, W1, W2, y_labels] + * + */ + AbstractTensorHandle* X = inputs[0]; + AbstractTensorHandle* W1 = inputs[1]; + AbstractTensorHandle* W2 = inputs[2]; + AbstractTensorHandle* y_labels = inputs[3]; + + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(W1)); // Watch W1. + tape->Watch(ToId(W2)); // Watch W2. + vector temp_outputs(1); + + TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), + "matmul0", /*transpose_a=*/false, + /*transpose_b=*/false, registry)); // Compute X*W1 + + TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]}, + absl::MakeSpan(temp_outputs), "relu", + registry)); // Compute Relu(X*W1) + + TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2}, + absl::MakeSpan(temp_outputs), "matmul1", + /*transpose_a=*/false, /*transpose_b=*/false, + registry)); // Compute W2*Relu(X*W1) + + AbstractTensorHandle* scores = temp_outputs[0]; + + temp_outputs.resize(2); + TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits( + ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs), + "softmax_loss", registry)); // Compute Softmax(Scores,labels) + + AbstractTensorHandle* loss_vals = temp_outputs[0]; + + outputs[0] = scores; + outputs[1] = loss_vals; + delete tape; + return Status::OK(); +} + +Status MatMulTransposeModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractTensorHandle* X = inputs[0]; + AbstractTensorHandle* W1 = inputs[1]; + + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(X)); + tape->Watch(ToId(W1)); + vector temp_outputs(1); + + TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), + "matmul0", /*transpose_a=*/true, + /*transpose_b=*/false, registry)); // Compute X*W1 + + outputs[0] = temp_outputs[0]; + + delete tape; + return Status::OK(); +} + +Status ReluGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch X + vector relu_outputs(1); + TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs), + "relu0", registry)); // Relu(X) + + std::unordered_map + source_tensors_that_are_targets; + + vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(relu_outputs[0])}, + 
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + + for (auto relu_output : relu_outputs) { + relu_output->Unref(); + } + + outputs[0] = out_grads[0]; + delete tape; + return Status::OK(); +} + +Status SoftmaxLossGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch scores. + tape->Watch(ToId(inputs[1])); // Watch labels. + vector sm_outputs(2); + TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits( + ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry)); + + std::unordered_map + source_tensors_that_are_targets; + + vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(sm_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; + delete tape; + return Status::OK(); +} + +Status MNISTGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractTensorHandle* X = inputs[0]; + AbstractTensorHandle* W1 = inputs[1]; + AbstractTensorHandle* W2 = inputs[2]; + AbstractTensorHandle* y_labels = inputs[3]; + + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/true); + tape->Watch(ToId(X)); // Watch X. + tape->Watch(ToId(W1)); // Watch W1. + tape->Watch(ToId(W2)); // Watch W1. + vector temp_outputs(1); + TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), + "matmul0", /*transpose_a=*/false, + /*transpose_b=*/false, registry)); // Compute X*W1 + + AbstractTensorHandle* mm = temp_outputs[0]; + + TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm}, + absl::MakeSpan(temp_outputs), // Relu(X*W1) + "relu0", registry)); + + AbstractTensorHandle* hidden = temp_outputs[0]; + + TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2}, + absl::MakeSpan(temp_outputs), "matmul1", + /*transpose_a=*/false, /*transpose_b=*/false, + registry)); // W2*Relu(X*W1) + + AbstractTensorHandle* scores = temp_outputs[0]; + + temp_outputs.resize(2); + TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits( + ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs), + "softmaxloss", registry)); // W2*Relu(X*W1) + + AbstractTensorHandle* loss = temp_outputs[0]; + + std::unordered_map + source_tensors_that_are_targets; + + vector out_grads; + TF_RETURN_IF_ERROR( + tape->ComputeGradient(vspace, /*target_tensor_ids=*/{ToId(loss)}, + /*source_tensor_ids=*/{ToId(W1), ToId(W2)}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + + // Only release 2nd temp output as first holds loss values. 
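  // [Editor's note] temp_outputs[0] (the loss) is intentionally kept alive:
  // it is returned to the caller as outputs[2] below, and the caller releases
  // it (see mnist_outputs[2]->Unref() in TestMNIST_Training).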
+ temp_outputs[1]->Unref(); + + outputs[0] = out_grads[0]; // dW1 + outputs[1] = out_grads[1]; // dW2 + outputs[2] = loss; + + delete tape; + return Status::OK(); +} + +Status ScalarMulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractTensorHandle* eta = inputs[0]; + AbstractTensorHandle* A = inputs[1]; + + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + vector temp_outputs(1); + + TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs), + "scalarMul0", registry)); // Compute eta*A + + outputs[0] = temp_outputs[0]; + + delete tape; + return Status::OK(); +} + +Status MatMulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractTensorHandle* X = inputs[0]; + AbstractTensorHandle* W1 = inputs[1]; + + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + std::vector temp_outputs(1); + TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), + "matmul0", /*transpose_a=*/false, + /*transpose_b=*/false, registry)); // Compute X*W1 + + outputs[0] = temp_outputs[0]; + delete tape; + return Status::OK(); +} + +Status MulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractTensorHandle* x = inputs[0]; + AbstractTensorHandle* y = inputs[1]; + + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + std::vector temp_outputs(1); + TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs), + "mul0", registry)); // Compute x*y + + outputs[0] = temp_outputs[0]; + delete tape; + return Status::OK(); +} + +Status SoftmaxModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractTensorHandle* x = inputs[0]; + AbstractTensorHandle* labels = inputs[1]; + + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + std::vector temp_outputs(2); + TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits( + ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss", + registry)); + + outputs[0] = temp_outputs[0]; // loss values + + delete tape; + return Status::OK(); +} + +// ============================= End Models ================================ + +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/eager/mnist_gradients_testutil.h b/tensorflow/c/eager/mnist_gradients_testutil.h new file mode 100644 index 00000000000..1cf87bb9dee --- /dev/null +++ b/tensorflow/c/eager/mnist_gradients_testutil.h @@ -0,0 +1,143 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_ +#define TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_ +#include <memory> + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/status.h" + +// ========================== Tape Ops ============================== + +namespace tensorflow { +namespace gradients { +namespace internal { +// Computes `inputs[0] + inputs[1]` and records it on the tape. +Status Add(AbstractContext* ctx, Tape* tape, + absl::Span<AbstractTensorHandle* const> inputs, + absl::Span<AbstractTensorHandle*> outputs, + const GradientRegistry& registry); + +// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape. +Status MatMul(AbstractContext* ctx, Tape* tape, + absl::Span<AbstractTensorHandle* const> inputs, + absl::Span<AbstractTensorHandle*> outputs, const char* name, + bool transpose_a, bool transpose_b, + const GradientRegistry& registry); + +// Computes `inputs[0] * inputs[1]` and records it on the tape. +Status Mul(AbstractContext* ctx, Tape* tape, + absl::Span<AbstractTensorHandle* const> inputs, + absl::Span<AbstractTensorHandle*> outputs, const char* name, + const GradientRegistry& registry); + +// Computes `Relu(inputs[0])` and records it on the tape. +Status Relu(AbstractContext* ctx, Tape* tape, + absl::Span<AbstractTensorHandle* const> inputs, + absl::Span<AbstractTensorHandle*> outputs, const char* name, + const GradientRegistry& registry); + +// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the +// tape. +Status SparseSoftmaxCrossEntropyWithLogits( + AbstractContext* ctx, Tape* tape, + absl::Span<AbstractTensorHandle* const> inputs, + absl::Span<AbstractTensorHandle*> outputs, const char* name, + const GradientRegistry& registry); + +// ====================== End Tape Ops ============================ + +// Computes +// y = inputs[0] + inputs[1] +// return grad(y, {inputs[0], inputs[1]}) +Status AddGradModel(AbstractContext* ctx, + absl::Span<AbstractTensorHandle* const> inputs, + absl::Span<AbstractTensorHandle*> outputs, + const GradientRegistry& registry); + +// Computes +// y = inputs[0] * inputs[1] +// return grad(y, {inputs[0], inputs[1]}) +Status MatMulGradModel(AbstractContext* ctx, + absl::Span<AbstractTensorHandle* const> inputs, + absl::Span<AbstractTensorHandle*> outputs, + const GradientRegistry& registry); + +// Computes 2-layer Neural Network with Softmax Loss. +Status MNISTForwardModel(AbstractContext* ctx, + absl::Span<AbstractTensorHandle* const> inputs, + absl::Span<AbstractTensorHandle*> outputs, + const GradientRegistry& registry); + +// Computes MatMul with first matrix transposed.
+Status MatMulTransposeModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +// Test Model to verify ReluGrad functionality +Status ReluGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +// Test Model to verify SoftmaxGrad functionality +Status SoftmaxLossGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +// Test Model to verify Multi-grad functionality for MNIST +Status MNISTGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +// Test Model to verify scalar-tensor multiplication Op +Status ScalarMulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +Status MatMulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +Status MulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +Status SoftmaxModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +} // namespace internal +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_ diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD index 0d0e5ffce10..3eec95294b3 100644 --- a/tensorflow/c/eager/parallel_device/BUILD +++ b/tensorflow/c/eager/parallel_device/BUILD @@ -76,10 +76,26 @@ cc_library( "//tensorflow/c/eager:c_api_experimental", "//tensorflow/core:lib", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", ], ) +tf_cc_test( + name = "parallel_device_lib_test", + srcs = ["parallel_device_lib_test.cc"], + deps = [ + ":parallel_device_lib", + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_experimental", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "parallel_device_testlib", testonly = 1, @@ -87,7 +103,6 @@ cc_library( hdrs = ["parallel_device_testlib.h"], deps = [ ":parallel_device", - ":parallel_device_ops", "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", "//tensorflow/c/eager:c_api", @@ -102,7 +117,6 @@ tf_cc_test( srcs = ["parallel_device_test.cc"], deps = [ ":parallel_device", - ":parallel_device_ops", ":parallel_device_testlib", "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", @@ -122,7 +136,6 @@ tf_cc_test( args = ["--heap_check=local"], deps = [ ":parallel_device", - ":parallel_device_ops", ":parallel_device_testlib", "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", @@ -134,19 +147,3 @@ tf_cc_test( "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], ) - -# Note: ParallelDevice-specific ops are experimental and not currently linked in -# to TensorFlow by default, just used in a few tests. 
-filegroup( - name = "parallel_device_ops_srcs", - srcs = ["parallel_device_ops.cc"], - visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"], -) - -cc_library( - name = "parallel_device_ops", - srcs = [":parallel_device_ops_srcs"], - visibility = ["//tensorflow:internal"], - deps = ["//tensorflow/core:framework"], - alwayslink = 1, -) diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc index d0e9f351478..41bde23448b 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device.cc @@ -136,13 +136,6 @@ absl::optional> ExecuteWithSpecialOps( } result.emplace(std::move(outputs)); return result; - } else if (operation_name == std::string("DeviceID")) { - std::vector result_content; - result_content.reserve(1); - result_content.push_back(parallel_device.DeviceIDs(context, status)); - if (TF_GetCode(status) != TF_OK) return result; - result.emplace(std::move(result_content)); - return result; } std::vector parallel_inputs; std::vector> implicitly_broadcast_tensors; @@ -255,28 +248,44 @@ TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context, // Since this function is used to satisfy the TFE_CustomDevice C API, // device_info is passed in using a C-style generic. It must always be a // ParallelDevice. -void ParallelDeviceExecute(TFE_Context* context, int num_inputs, - TFE_TensorHandle** inputs, - const char* operation_name, - const TFE_OpAttrs* attributes, int* num_outputs, +void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs, TFE_TensorHandle** outputs, TF_Status* status, void* device_info) { + const char* requested_placement = TFE_OpGetDevice(original_op, status); + if (*requested_placement == '\0') { + TF_SetStatus( + status, TF_INTERNAL, + "Ops must be placed on the parallel device explicitly, or their inputs " + "first un-packed. Got an un-placed op with an input placed on the " + "parallel device."); + return; + } + TFE_Context* context = TFE_OpGetContext(original_op, status); + if (TF_GetCode(status) != TF_OK) return; + const char* operation_name = TFE_OpGetName(original_op, status); + if (TF_GetCode(status) != TF_OK) return; + const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op); + NamedParallelDevice* named_device = reinterpret_cast(device_info); std::vector typed_inputs; + int num_inputs = TFE_OpGetFlatInputCount(original_op, status); + if (TF_GetCode(status) != TF_OK) return; typed_inputs.reserve(num_inputs); for (int i = 0; i < num_inputs; ++i) { + TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status); + if (TF_GetCode(status) != TF_OK) return; const char* tensor_handle_device = - TFE_TensorHandleDeviceName(inputs[i], status); + TFE_TensorHandleDeviceName(input, status); if (TF_GetCode(status) != TF_OK) return; if (named_device->name() == tensor_handle_device) { // We assume that any tensors already placed on this device are // ParallelTensors. 
typed_inputs.emplace_back(reinterpret_cast( - TFE_TensorHandleDevicePointer(inputs[i], status))); + TFE_TensorHandleDevicePointer(input, status))); if (TF_GetCode(status) != TF_OK) return; } else { - typed_inputs.emplace_back(inputs[i]); + typed_inputs.emplace_back(input); } } diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc index 768f686bd88..e270bfcbb80 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h" +#include "tensorflow/c/tf_status.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" @@ -118,6 +119,9 @@ class DeviceThread { int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_); // Outputs std::vector op_outputs_ TF_GUARDED_BY(execution_mutex_); + // TF_Status is an incomplete type and so can't be stack allocated. To avoid + // unnecessary allocations each Execute call, we keep one heap-allocated + // version for the thread. StatusPtr status_ TF_GUARDED_BY(execution_mutex_); const std::string device_; @@ -188,6 +192,9 @@ std::vector DeviceThread::Join(TF_Status* status) { if (TF_GetCode(status_.get()) != TF_OK) { TF_SetStatus(status, TF_GetCode(status_.get()), TF_Message(status_.get())); + // Reset the member `status_` so future op executions (after recovery from + // the bad `status`) start with an OK status. + TF_SetStatus(status_.get(), TF_OK, ""); } execution_state_ = ExecutionState::kIdle; result = std::move(op_outputs_); @@ -255,18 +262,27 @@ std::unique_ptr ParallelDevice::CopyToParallelDevice( status); } -std::unique_ptr ParallelDevice::DeviceIDs( - TFE_Context* context, TF_Status* status) const { +std::unique_ptr ParallelDevice::Vector( + TFE_Context* context, TF_Status* status, + absl::Span values) const { // TODO(allenl): We could cache DeviceIDs (keyed by context). 
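// Illustrative caller sketch (hypothetical, not part of this change): Vector
// generalizes DeviceIDs by packing one caller-supplied int32 scalar per
// underlying device. Assuming a two-device ParallelDevice `pd`, a
// TFE_Context* `context`, and a TF_Status* `status`:
//
//   std::unique_ptr<ParallelTensor> per_device =
//       pd.Vector(context, status, {7, 9});  // 7 on device 0, 9 on device 1
//   std::unique_ptr<ParallelTensor> ids =
//       pd.DeviceIDs(context, status);       // 0 on device 0, 1 on device 1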
std::vector components; components.reserve(underlying_devices_.size()); - for (int device_index = 0; device_index < underlying_devices_.size(); + + if (values.size() != num_underlying_devices()) { + TF_SetStatus( + status, TF_INVALID_ARGUMENT, + "Number of values did not match number of underlying devices."); + return nullptr; + } + + for (int device_index = 0; device_index < num_underlying_devices(); ++device_index) { - int32_t* device_id = new int32_t; - *device_id = device_index; + int32_t* device_value = new int32_t; + *device_value = values[device_index]; std::unique_ptr tensor( TF_NewTensor( - TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id, + TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_value, sizeof(int32_t), [](void* data, size_t, void* arg) { delete reinterpret_cast(data); @@ -295,6 +311,16 @@ std::unique_ptr ParallelDevice::DeviceIDs( status); } +std::unique_ptr ParallelDevice::DeviceIDs( + TFE_Context* context, TF_Status* status) const { + std::vector ids; + ids.reserve(num_underlying_devices()); + for (int i = 0; i < num_underlying_devices(); ++i) { + ids.push_back(i); + } + return Vector(context, status, ids); +} + absl::optional>> ParallelDevice::Execute(TFE_Context* context, const std::vector& inputs, @@ -319,21 +345,36 @@ ParallelDevice::Execute(TFE_Context* context, std::move(device_inputs), attributes, expected_max_outputs); } + StatusPtr first_bad_status(nullptr); for (int device_index = 0; device_index < underlying_devices_.size(); ++device_index) { DeviceThread* device_thread = device_threads_[device_index].get(); per_device_output_tensors.push_back(device_thread->Join(status)); - if (TF_GetCode(status) != TF_OK) return result; + // We will run every Join even if there are bad statuses in case the user + // wants to recover and continue running ops on the parallel device (which + // would otherwise deadlock). + if (TF_GetCode(status) != TF_OK && first_bad_status == nullptr) { + first_bad_status.reset(TF_NewStatus()); + TF_SetStatus(first_bad_status.get(), TF_GetCode(status), + TF_Message(status)); + } + if (device_index == 0) { first_op_output_count = per_device_output_tensors.rbegin()->size(); } else { - if (per_device_output_tensors.rbegin()->size() != first_op_output_count) { - TF_SetStatus(status, TF_INTERNAL, + if (first_bad_status == nullptr && + per_device_output_tensors.rbegin()->size() != first_op_output_count) { + first_bad_status.reset(TF_NewStatus()); + TF_SetStatus(first_bad_status.get(), TF_INTERNAL, "Parallel ops produced different numbers of tensors."); - return result; } } } + if (first_bad_status != nullptr) { + TF_SetStatus(status, TF_GetCode(first_bad_status.get()), + TF_Message(first_bad_status.get())); + return result; + } // For each output of the original operation, pack the per-device // TensorHandles we've computed into a single parallel TensorHandle. std::vector> per_device_outputs; diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/tensorflow/c/eager/parallel_device/parallel_device_lib.h index cbfea31d95f..b3dc47ab088 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.h +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.h @@ -21,6 +21,7 @@ limitations under the License. 
#include #include "absl/types/optional.h" +#include "absl/types/span.h" #include "absl/types/variant.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" @@ -61,6 +62,11 @@ class ParallelDevice { TFE_TensorHandle* tensor, TF_Status* status) const; + // Construct a parallel tensor consisting of the scalar values from `values`. + std::unique_ptr Vector( + TFE_Context* context, TF_Status* status, + absl::Span values) const; + // A parallel tensor with scalar integers numbering component devices. std::unique_ptr DeviceIDs(TFE_Context* context, TF_Status* status) const; diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc new file mode 100644 index 00000000000..35befe959cb --- /dev/null +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace parallel_device { + +TEST(PARALLEL_DEVICE_LIB, TestOpWithError) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr config( + TF_CreateConfig( + /*xla*/ false, + /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ + 2), + TF_DeleteBuffer); + TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, + status.get()); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + std::vector devices{ + "/job:localhost/replica:0/task:0/device:CPU:0", + "/job:localhost/replica:0/task:0/device:CPU:1"}; + ParallelDevice parallel_device(std::move(devices)); + std::unique_ptr handle_op( + TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT); + TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0, + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + auto outputs = + parallel_device.Execute(context.get(), std::vector(), + "VarHandleOp", TFE_OpGetAttrs(handle_op.get()), + /*expected_max_outputs=*/1, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + const std::vector>& handles = *outputs; + std::vector handle_inputs; + handle_inputs.reserve(handles.size()); + for (auto& handle : handles) { + handle_inputs.push_back(handle.get()); + } + std::unique_ptr 
read_op( + TFE_NewOp(context.get(), "ReadVariableOp", status.get()), TFE_DeleteOp); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpSetAttrType(read_op.get(), "dtype", TF_FLOAT); + parallel_device.Execute(context.get(), handle_inputs, "ReadVariableOp", + TFE_OpGetAttrs(read_op.get()), + /*expected_max_outputs=*/1, status.get()); + ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK); + TF_SetStatus(status.get(), TF_OK, ""); + + // Check that ops still run successfully on the device. + parallel_device.Execute(context.get(), std::vector(), + "VarHandleOp", TFE_OpGetAttrs(handle_op.get()), + /*expected_max_outputs=*/1, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); +} + +} // namespace parallel_device +} // namespace tensorflow diff --git a/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc b/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc index 828dcbae093..67bc596b180 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc @@ -279,30 +279,4 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); ASSERT_EQ(underlying_devices[1], second_device); } - // Compute the device ID twice and verify the result - for (int i = 0; i < 2; ++i) { - std::unique_ptr op( - TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - TFE_OpSetDevice(op.get(), device_name, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - TFE_TensorHandle* result_handle; - int num_retvals = 1; - TFE_Execute(op.get(), &result_handle, &num_retvals, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - std::array components; - ExtractPerDeviceValues(context, result_handle, &components, status.get()); - TFE_DeleteTensorHandle(result_handle); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - ExpectScalarEq(components[0].get(), 0); - ExpectScalarEq(components[1].get(), 1); - std::string first_device = - TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); - ASSERT_EQ(underlying_devices[0], first_device); - std::string second_device = - TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); - ASSERT_EQ(underlying_devices[1], second_device); - } } diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 27629bb3bdf..fcebe973500 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -146,13 +146,16 @@ class GradientTape { // once) and produces the gradient of the target tensors with respect to the // source tensors. The output gradients are used if not empty and not // null. The result is populated with one tensor per target element. + // When running backward functions, builds zeros-like tensors for + // incoming grads which are nullptrs, unless `build_default_zeros_grads` + // is set to false. 
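// Example (hypothetical sketch, assuming a tape `t`, a `vspace`, and recorded
// tensor ids `y_id` / `x_id`): a caller that wants to see missing gradients
// directly can disable the default zero-building:
//
//   std::vector<Gradient*> grads;
//   TF_RETURN_IF_ERROR(t->ComputeGradient(
//       vspace, /*target_tensor_ids=*/{y_id}, /*source_tensor_ids=*/{x_id},
//       /*sources_that_are_targets=*/{}, /*output_gradients=*/{}, &grads,
//       /*build_default_zeros_grads=*/false));
//
// With the flag set to false, backward functions may receive nullptr incoming
// gradients and are expected to handle them themselves.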
Status ComputeGradient( const VSpace& vspace, const gtl::ArraySlice target_tensor_ids, const gtl::ArraySlice source_tensor_ids, const std::unordered_map& sources_that_are_targets, gtl::ArraySlice output_gradients, - std::vector* result); + std::vector* result, bool build_default_zeros_grads = true); bool IsPersistent() const { return persistent_; } @@ -655,8 +658,8 @@ Status GradientTape::ComputeGradient( const gtl::ArraySlice target_tensor_ids, const gtl::ArraySlice source_tensor_ids, const std::unordered_map& sources_that_are_targets, - gtl::ArraySlice output_gradients, - std::vector* result) { + gtl::ArraySlice output_gradients, std::vector* result, + bool build_default_zeros_grads) { std::unordered_set sources_set(source_tensor_ids.begin(), source_tensor_ids.end()); BackpropInitialState state = PrepareBackprop( @@ -717,14 +720,14 @@ Status GradientTape::ComputeGradient( const int64 id = trace.output_tensor_info[i].GetID(); auto grad_it = gradients.find(id); if (grad_it == gradients.end()) { - auto func_name_it = - FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type); - if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() && - func_name_it->second.find(i) != func_name_it->second.end()) { - out_gradients.push_back(nullptr); - } else { - out_gradients.push_back(nullptr); - zero_indices.push_back(i); + out_gradients.push_back(nullptr); + if (build_default_zeros_grads) { + auto func_name_it = + FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type); + if (func_name_it == FunctionsAcceptingNoneForIndicesMap()->end() || + func_name_it->second.find(i) == func_name_it->second.end()) { + zero_indices.push_back(i); + } } } else { any_gradient_nonzero = true; @@ -745,6 +748,7 @@ Status GradientTape::ComputeGradient( } } std::vector in_gradients; + DCHECK(build_default_zeros_grads || zero_indices.empty()); if (any_gradient_nonzero) { for (const auto i : zero_indices) { out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem.cc b/tensorflow/c/experimental/filesystem/modular_filesystem.cc index 00a587521fd..9c8d3518800 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/modular_filesystem.cc @@ -35,8 +35,8 @@ using UniquePtrTo_TF_Status = ::std::unique_ptr; Status ModularFileSystem::NewRandomAccessFile( - const std::string& fname, - std::unique_ptr* result /*, TransactionToken* token */) { + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { if (ops_->new_random_access_file == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", fname, " does not support NewRandomAccessFile()")); @@ -55,8 +55,8 @@ Status ModularFileSystem::NewRandomAccessFile( } Status ModularFileSystem::NewWritableFile( - const std::string& fname, - std::unique_ptr* result /*, TransactionToken* token */) { + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { if (ops_->new_writable_file == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", fname, " does not support NewWritableFile()")); @@ -75,8 +75,8 @@ Status ModularFileSystem::NewWritableFile( } Status ModularFileSystem::NewAppendableFile( - const std::string& fname, - std::unique_ptr* result /*, TransactionToken* token */) { + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { if (ops_->new_appendable_file == nullptr) return 
errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", fname, " does not support NewAppendableFile()")); @@ -95,8 +95,8 @@ Status ModularFileSystem::NewAppendableFile( } Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile( - const std::string& fname, std::unique_ptr* - result /*, TransactionToken* token */) { + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { if (ops_->new_read_only_memory_region_from_file == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", fname, @@ -116,8 +116,8 @@ Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile( return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::FileExists( - const std::string& fname /*, TransactionToken* token */) { +Status ModularFileSystem::FileExists(const std::string& fname, + TransactionToken* token) { if (ops_->path_exists == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", fname, " does not support FileExists()")); @@ -129,9 +129,9 @@ Status ModularFileSystem::FileExists( return StatusFromTF_Status(plugin_status.get()); } -bool ModularFileSystem::FilesExist( - const std::vector& files, - std::vector* status /*, TransactionToken* token */) { +bool ModularFileSystem::FilesExist(const std::vector& files, + TransactionToken* token, + std::vector* status) { if (ops_->paths_exist == nullptr) return FileSystem::FilesExist(files, status); @@ -162,9 +162,9 @@ bool ModularFileSystem::FilesExist( return result; } -Status ModularFileSystem::GetChildren( - const std::string& dir, - std::vector* result /*, TransactionToken* token */) { +Status ModularFileSystem::GetChildren(const std::string& dir, + TransactionToken* token, + std::vector* result) { if (ops_->get_children == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", dir, " does not support GetChildren()")); @@ -188,9 +188,9 @@ Status ModularFileSystem::GetChildren( return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::GetMatchingPaths( - const std::string& pattern, - std::vector* result /*, TransactionToken* token */) { +Status ModularFileSystem::GetMatchingPaths(const std::string& pattern, + TransactionToken* token, + std::vector* result) { if (ops_->get_matching_paths == nullptr) return internal::GetMatchingPaths(this, Env::Default(), pattern, result); @@ -211,8 +211,8 @@ Status ModularFileSystem::GetMatchingPaths( return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::DeleteFile( - const std::string& fname /*, TransactionToken* token */) { +Status ModularFileSystem::DeleteFile(const std::string& fname, + TransactionToken* token) { if (ops_->delete_file == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", fname, " does not support DeleteFile()")); @@ -224,9 +224,10 @@ Status ModularFileSystem::DeleteFile( return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::DeleteRecursively( - const std::string& dirname, int64* undeleted_files, - int64* undeleted_dirs /*, TransactionToken* token */) { +Status ModularFileSystem::DeleteRecursively(const std::string& dirname, + TransactionToken* token, + int64* undeleted_files, + int64* undeleted_dirs) { if (undeleted_files == nullptr || undeleted_dirs == nullptr) return errors::FailedPrecondition( "DeleteRecursively must not be called with `undeleted_files` or " @@ -247,8 +248,8 @@ Status ModularFileSystem::DeleteRecursively( return 
StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::DeleteDir( - const std::string& dirname /*, TransactionToken* token */) { +Status ModularFileSystem::DeleteDir(const std::string& dirname, + TransactionToken* token) { if (ops_->delete_dir == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", dirname, " does not support DeleteDir()")); @@ -260,8 +261,8 @@ Status ModularFileSystem::DeleteDir( return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::RecursivelyCreateDir( - const std::string& dirname /*, TransactionToken* token */) { +Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname, + TransactionToken* token) { if (ops_->recursively_create_dir == nullptr) return FileSystem::RecursivelyCreateDir(dirname); @@ -272,8 +273,8 @@ Status ModularFileSystem::RecursivelyCreateDir( return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::CreateDir( - const std::string& dirname /*, TransactionToken* token */) { +Status ModularFileSystem::CreateDir(const std::string& dirname, + TransactionToken* token) { if (ops_->create_dir == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", dirname, " does not support CreateDir()")); @@ -285,9 +286,8 @@ Status ModularFileSystem::CreateDir( return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::Stat( - const std::string& fname, - FileStatistics* stat /*, TransactionToken* token */) { +Status ModularFileSystem::Stat(const std::string& fname, + TransactionToken* token, FileStatistics* stat) { if (ops_->stat == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", fname, " does not support Stat()")); @@ -310,8 +310,8 @@ Status ModularFileSystem::Stat( return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::IsDirectory( - const std::string& name /*, TransactionToken* token */) { +Status ModularFileSystem::IsDirectory(const std::string& name, + TransactionToken* token) { if (ops_->is_directory == nullptr) return FileSystem::IsDirectory(name); UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus); @@ -321,9 +321,9 @@ Status ModularFileSystem::IsDirectory( return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::GetFileSize( - const std::string& fname, - uint64* file_size /*, TransactionToken* token */) { +Status ModularFileSystem::GetFileSize(const std::string& fname, + TransactionToken* token, + uint64* file_size) { if (ops_->get_file_size == nullptr) { FileStatistics stat; Status status = Stat(fname, &stat); @@ -342,9 +342,9 @@ Status ModularFileSystem::GetFileSize( return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::RenameFile( - const std::string& src, - const std::string& target /*, TransactionToken* token */) { +Status ModularFileSystem::RenameFile(const std::string& src, + const std::string& target, + TransactionToken* token) { if (ops_->rename_file == nullptr) { Status status = CopyFile(src, target); if (status.ok()) status = DeleteFile(src); @@ -359,9 +359,9 @@ Status ModularFileSystem::RenameFile( return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::CopyFile( - const std::string& src, - const std::string& target /*, TransactionToken* token */) { +Status ModularFileSystem::CopyFile(const std::string& src, + const std::string& target, + TransactionToken* token) { if (ops_->copy_file == nullptr) return FileSystem::CopyFile(src, target); 
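// Illustrative call shape (hypothetical, not part of this change): with the
// TransactionToken* parameter now explicit in each signature, callers that do
// not use transactions pass nullptr. For a ModularFileSystem* `fs` and a path
// `fname`:
//
//   std::unique_ptr<WritableFile> file;
//   TF_RETURN_IF_ERROR(fs->NewWritableFile(fname, /*token=*/nullptr, &file));
//   TF_RETURN_IF_ERROR(fs->FileExists(fname, /*token=*/nullptr));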
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus); @@ -372,8 +372,7 @@ Status ModularFileSystem::CopyFile( return StatusFromTF_Status(plugin_status.get()); } -std::string ModularFileSystem::TranslateName( - const std::string& name /*, TransactionToken* token */) const { +std::string ModularFileSystem::TranslateName(const std::string& name) const { if (ops_->translate_name == nullptr) return FileSystem::TranslateName(name); char* p = ops_->translate_name(filesystem_.get(), name.c_str()); @@ -385,7 +384,7 @@ std::string ModularFileSystem::TranslateName( return ret; } -void ModularFileSystem::FlushCaches(/*TransactionToken* token*/) { +void ModularFileSystem::FlushCaches(TransactionToken* token) { if (ops_->flush_caches != nullptr) ops_->flush_caches(filesystem_.get()); } diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem.h b/tensorflow/c/experimental/filesystem/modular_filesystem.h index a2639152eff..061a1aa446b 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem.h +++ b/tensorflow/c/experimental/filesystem/modular_filesystem.h @@ -59,71 +59,48 @@ class ModularFileSystem final : public FileSystem { ~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); } + TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT; + Status NewRandomAccessFile( - const std::string& fname, - std::unique_ptr* - result /*, TransactionToken* token = nullptr */) override; - Status NewWritableFile( - const std::string& fname, - std::unique_ptr* - result /*, TransactionToken* token = nullptr */) override; - Status NewAppendableFile( - const std::string& fname, - std::unique_ptr* - result /*, TransactionToken* token = nullptr */) override; + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) override; + Status NewWritableFile(const std::string& fname, TransactionToken* token, + std::unique_ptr* result) override; + Status NewAppendableFile(const std::string& fname, TransactionToken* token, + std::unique_ptr* result) override; Status NewReadOnlyMemoryRegionFromFile( - const std::string& fname, - std::unique_ptr* - result /*, TransactionToken* token = nullptr */) override; - Status FileExists( - const std::string& fname /*, TransactionToken* token = nullptr */) - override; + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) override; + Status FileExists(const std::string& fname, TransactionToken* token) override; bool FilesExist(const std::vector& files, - std::vector* - status /*, TransactionToken* token = nullptr */) override; - Status GetChildren( - const std::string& dir, - std::vector* result /*, TransactionToken* token = nullptr */) - override; - Status GetMatchingPaths( - const std::string& pattern, - std::vector* - results /*, TransactionToken* token = nullptr */) override; - Status DeleteFile( - const std::string& fname /*, TransactionToken* token = nullptr */) - override; - Status DeleteRecursively( - const std::string& dirname, int64* undeleted_files, - int64* undeleted_dirs /*, TransactionToken* token = nullptr */) override; - Status DeleteDir( - const std::string& dirname /*, TransactionToken* token = nullptr */) - override; - Status RecursivelyCreateDir( - const std::string& dirname /*, TransactionToken* token = nullptr */) - override; - Status CreateDir( - const std::string& dirname /*, TransactionToken* token = nullptr */) - override; - Status Stat( - const std::string& fname, - FileStatistics* stat /*, TransactionToken* token = nullptr */) override; - Status IsDirectory( - const 
std::string& fname /*, TransactionToken* token = nullptr */) - override; - Status GetFileSize( - const std::string& fname, - uint64* file_size /*, TransactionToken* token = nullptr */) override; - Status RenameFile( - const std::string& src, - const std::string& target /*, TransactionToken* token = nullptr */) - override; - Status CopyFile(const std::string& src, - const std::string& - target /*, TransactionToken* token = nullptr */) override; - std::string TranslateName( - const std::string& name /*, TransactionToken* token = nullptr */) - const override; - void FlushCaches(/* TransactionToken* token=nullptr */) override; + TransactionToken* token, + std::vector* status) override; + Status GetChildren(const std::string& dir, TransactionToken* token, + std::vector* result) override; + Status GetMatchingPaths(const std::string& pattern, TransactionToken* token, + std::vector* results) override; + Status DeleteFile(const std::string& fname, TransactionToken* token) override; + Status DeleteRecursively(const std::string& dirname, TransactionToken* token, + int64* undeleted_files, + int64* undeleted_dirs) override; + Status DeleteDir(const std::string& dirname, + TransactionToken* token) override; + Status RecursivelyCreateDir(const std::string& dirname, + TransactionToken* token) override; + Status CreateDir(const std::string& dirname, + TransactionToken* token) override; + Status Stat(const std::string& fname, TransactionToken* token, + FileStatistics* stat) override; + Status IsDirectory(const std::string& fname, + TransactionToken* token) override; + Status GetFileSize(const std::string& fname, TransactionToken* token, + uint64* file_size) override; + Status RenameFile(const std::string& src, const std::string& target, + TransactionToken* token) override; + Status CopyFile(const std::string& src, const std::string& target, + TransactionToken* token) override; + std::string TranslateName(const std::string& name) const override; + void FlushCaches(TransactionToken* token) override; private: std::unique_ptr filesystem_; diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD index b2636571c25..54217db1de0 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD @@ -29,10 +29,12 @@ cc_library( ":gcs_helper", ":ram_file_block_cache", "//tensorflow/c:env", + "//tensorflow/c:logging", "//tensorflow/c:tf_status", "//tensorflow/c/experimental/filesystem:filesystem_interface", "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:variant", ], ) @@ -58,6 +60,7 @@ cc_library( deps = [ ":cleanup", "//tensorflow/c:env", + "//tensorflow/c:logging", "//tensorflow/c:tf_status", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc index b6b481cda66..8cd8ad7ca81 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc @@ -19,9 +19,11 @@ limitations under the License. 
#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "absl/types/variant.h" #include "google/cloud/storage/client.h" #include "tensorflow/c/env.h" #include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h" +#include "tensorflow/c/logging.h" #include "tensorflow/c/tf_status.h" // Implementation of a filesystem for GCS environments. @@ -119,20 +121,20 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset, return -1; } int64_t read; - if (!absl::SimpleAtoi(stream.headers().find("content-length")->second, - &read)) { + auto content_length = stream.headers().find("content-length"); + if (content_length == stream.headers().end()) { // When we read a file with offset that is bigger than the actual file size. // GCS will return an empty header (e.g no `content-length` header). In this // case, we will set read to `0` and continue. - if (TF_GetCode(status) == TF_OUT_OF_RANGE) { - read = 0; - } else { - TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header"); - return -1; - } + read = 0; + } else if (!absl::SimpleAtoi(content_length->second, &read)) { + TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header"); + return -1; } // `TF_OUT_OF_RANGE` isn't considered as an error. So we clear it here. TF_SetStatus(status, TF_OK, ""); + TF_VLog(1, "Successful read of %s @ %u of size: %u", path.c_str(), offset, + read); stream.read(buffer, read); read = stream.gcount(); if (read < buffer_size) { @@ -145,6 +147,8 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset, path, " @ ", offset) .c_str()); } + TF_VLog(2, "Successful integrity check for: %s @ %u", path.c_str(), + offset); } } return read; @@ -258,7 +262,8 @@ static void SyncImpl(const std::string& bucket, const std::string& object, if (*offset == -1 || *offset == 0) { // UploadFile will automatically switch to resumable upload based on Client // configuration. 
- auto metadata = gcs_client->UploadFile(outfile->getName(), bucket, object); + auto metadata = gcs_client->UploadFile(outfile->getName(), bucket, object, + gcs::Fields("size")); if (!metadata) { TF_SetStatusFromGCSStatus(metadata.status(), status); return; @@ -277,15 +282,18 @@ static void SyncImpl(const std::string& bucket, const std::string& object, } else { std::string temporary_object = gcs::CreateRandomPrefixName("tf_writable_file_gcs"); - auto metadata = - gcs_client->UploadFile(outfile->getName(), bucket, temporary_object); + auto metadata = gcs_client->UploadFile(outfile->getName(), bucket, + temporary_object, gcs::Fields("")); if (!metadata) { TF_SetStatusFromGCSStatus(metadata.status(), status); return; } + TF_VLog(3, "AppendObject: gs://%s/%s to gs://%s/%s", bucket.c_str(), + temporary_object.c_str(), bucket.c_str(), object.c_str()); const std::vector source_objects = { {object, {}, {}}, {temporary_object, {}, {}}}; - metadata = gcs_client->ComposeObject(bucket, source_objects, object); + metadata = gcs_client->ComposeObject(bucket, source_objects, object, + gcs::Fields("size")); if (!metadata) { TF_SetStatusFromGCSStatus(metadata.status(), status); return; @@ -320,6 +328,8 @@ void Append(const TF_WritableFile* file, const char* buffer, size_t n, "The internal temporary file is not writable."); return; } + TF_VLog(3, "Append: gs://%s/%s size %u", gcs_file->bucket.c_str(), + gcs_file->object.c_str(), n); gcs_file->sync_need = true; gcs_file->outfile.write(buffer, n); if (!gcs_file->outfile) @@ -345,6 +355,8 @@ int64_t Tell(const TF_WritableFile* file, TF_Status* status) { void Flush(const TF_WritableFile* file, TF_Status* status) { auto gcs_file = static_cast(file->plugin_file); if (gcs_file->sync_need) { + TF_VLog(3, "Flush started: gs://%s/%s", gcs_file->bucket.c_str(), + gcs_file->object.c_str()); if (!gcs_file->outfile) { TF_SetStatus(status, TF_INTERNAL, "Could not append to the internal temporary file."); @@ -352,6 +364,8 @@ void Flush(const TF_WritableFile* file, TF_Status* status) { } SyncImpl(gcs_file->bucket, gcs_file->object, &gcs_file->offset, &gcs_file->outfile, gcs_file->gcs_client, status); + TF_VLog(3, "Flush finished: gs://%s/%s", gcs_file->bucket.c_str(), + gcs_file->object.c_str()); if (TF_GetCode(status) != TF_OK) return; gcs_file->sync_need = false; } else { @@ -360,11 +374,16 @@ void Flush(const TF_WritableFile* file, TF_Status* status) { } void Sync(const TF_WritableFile* file, TF_Status* status) { + auto gcs_file = static_cast(file->plugin_file); + TF_VLog(3, "Sync: gs://%s/%s", gcs_file->bucket.c_str(), + gcs_file->object.c_str()); Flush(file, status); } void Close(const TF_WritableFile* file, TF_Status* status) { auto gcs_file = static_cast(file->plugin_file); + TF_VLog(3, "Close: gs://%s/%s", gcs_file->bucket.c_str(), + gcs_file->object.c_str()); if (gcs_file->sync_need) { Flush(file, status); } @@ -427,6 +446,8 @@ GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client) if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) { max_staleness = value; } + TF_VLog(1, "GCS cache max size = %u ; block size = %u ; max staleness = %u", + max_bytes, block_size, max_staleness); file_block_cache = std::make_unique( block_size, max_bytes, max_staleness, @@ -503,13 +524,18 @@ void Cleanup(TF_Filesystem* filesystem) { static void UncachedStatForObject(const std::string& bucket, const std::string& object, GcsFileStat* stat, gcs::Client* gcs_client, TF_Status* status) { - auto metadata = gcs_client->GetObjectMetadata(bucket, object); + auto metadata = 
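// Note on gcs::Fields (illustrative sketch, not part of this change): the
// option asks GCS to return only the listed metadata fields, keeping responses
// small. For example, fetching just an object's size from a gcs::Client*
// `client`:
//
//   auto meta = client->GetObjectMetadata(bucket, object, gcs::Fields("size"));
//   if (meta) std::uint64_t object_size = meta->size();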
gcs_client->GetObjectMetadata( + bucket, object, gcs::Fields("generation,size,timeStorageClassUpdated")); if (!metadata) return TF_SetStatusFromGCSStatus(metadata.status(), status); stat->generation_number = metadata->generation(); stat->base.length = metadata->size(); stat->base.mtime_nsec = metadata->time_storage_class_updated().time_since_epoch().count(); stat->base.is_directory = object.back() == '/'; + TF_VLog(1, + "Stat of: gs://%s/%s -- length: %u generation: %u; mtime_nsec: %u;", + bucket.c_str(), object.c_str(), stat->base.length, + stat->generation_number, stat->base.mtime_nsec); return TF_SetStatus(status, TF_OK, ""); } @@ -544,9 +570,10 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, if (TF_GetCode(status) != TF_OK) return -1; if (!gcs_file->file_block_cache->ValidateAndUpdateFileSignature( path, stat.generation_number)) { - std::cout - << "File signature has been changed. Refreshing the cache. Path: " - << path; + TF_VLog( + 1, + "File signature has been changed. Refreshing the cache. Path: %s", + path.c_str()); } read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status); } else { @@ -578,6 +605,7 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path, (gcs_file->compose ? 0 : -1)}); // We are responsible for freeing the pointer returned by TF_GetTempFileName free(temp_file_name); + TF_VLog(3, "GcsWritableFile: %s", path); TF_SetStatus(status, TF_OK, ""); } @@ -607,7 +635,8 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, } else { // If compose is true, we do not download anything. // Instead we only check if this file exists on server or not. - auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object); + auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object, + gcs::Fields("size")); TF_SetStatusFromGCSStatus(metadata.status(), status); if (TF_GetCode(status) == TF_OK) { file->plugin_file = new tf_writable_file::GCSFile( @@ -623,7 +652,8 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, return; } } - + TF_VLog(3, "GcsWritableFile: %s with existing file %s", path, + temp_file_name.c_str()); TF_SetStatus(status, TF_OK, ""); } @@ -638,7 +668,8 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, if (TF_GetCode(status) != TF_OK) return; auto gcs_file = static_cast(filesystem->plugin_filesystem); - auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object); + auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object, + gcs::Fields("size")); if (!metadata) { TF_SetStatusFromGCSStatus(metadata.status(), status); return; @@ -663,28 +694,190 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, } } -void CreateDir(const TF_Filesystem* filesystem, const char* path, - TF_Status* status) { +static void StatForObject(GCSFile* gcs_file, const std::string& path, + const std::string& bucket, const std::string& object, + GcsFileStat* stat, TF_Status* status) { + if (object.empty()) + return TF_SetStatus( + status, TF_INVALID_ARGUMENT, + absl::StrCat("'object' must be a non-empty string. 
(File: ", path, ")") + .c_str()); + TF_SetStatus(status, TF_OK, ""); + gcs_file->stat_cache->LookupOrCompute( + path, stat, + [gcs_file, bucket, object](const std::string& path, GcsFileStat* stat, + TF_Status* status) { + UncachedStatForObject(bucket, object, stat, &gcs_file->gcs_client, + status); + }, + status); +} + +static bool ObjectExists(GCSFile* gcs_file, const std::string& path, + const std::string& bucket, const std::string& object, + TF_Status* status) { + GcsFileStat stat; + StatForObject(gcs_file, path, bucket, object, &stat, status); + if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND) + return false; + if (TF_GetCode(status) == TF_NOT_FOUND) { + TF_SetStatus(status, TF_OK, ""); + return false; + } + return !stat.base.is_directory; +} + +static bool BucketExists(GCSFile* gcs_file, const std::string& bucket, + TF_Status* status) { + auto metadata = + gcs_file->gcs_client.GetBucketMetadata(bucket, gcs::Fields("")); + TF_SetStatusFromGCSStatus(metadata.status(), status); + if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND) + return false; + if (TF_GetCode(status) == TF_NOT_FOUND) { + TF_SetStatus(status, TF_OK, ""); + return false; + } + return true; +} + +static std::vector GetChildrenBounded( + GCSFile* gcs_file, std::string dir, uint64_t max_results, bool recursive, + bool include_self_directory_marker, TF_Status* status) { + std::string bucket, prefix; + MaybeAppendSlash(&dir); + ParseGCSPath(dir, true, &bucket, &prefix, status); + + std::vector result; + uint64_t count = 0; + std::string delimiter = recursive ? "" : "/"; + + for (auto&& item : gcs_file->gcs_client.ListObjectsAndPrefixes( + bucket, gcs::Prefix(prefix), gcs::Delimiter(delimiter), + gcs::Fields("items(name),prefixes"))) { + if (count == max_results) { + TF_SetStatus(status, TF_OK, ""); + return result; + } + if (!item) { + TF_SetStatusFromGCSStatus(item.status(), status); + return result; + } + auto value = *std::move(item); + std::string children = absl::holds_alternative(value) + ? 
absl::get(value) + : absl::get(value).name(); + auto pos = children.find(prefix); + if (pos != 0) { + TF_SetStatus(status, TF_INTERNAL, + absl::StrCat("Unexpected response: the returned file name ", + children, " doesn't match the prefix ", prefix) + .c_str()); + return result; + } + children.erase(0, prefix.length()); + if (!children.empty() || include_self_directory_marker) { + result.emplace_back(children); + } + ++count; + } + + return result; +} + +static bool FolderExists(GCSFile* gcs_file, std::string dir, + TF_Status* status) { + ExpiringLRUCache::ComputeFunc compute_func = + [gcs_file](const std::string& dir, GcsFileStat* stat, TF_Status* status) { + auto children = + GetChildrenBounded(gcs_file, dir, 1, true, true, status); + if (TF_GetCode(status) != TF_OK) return; + if (!children.empty()) { + stat->base = {0, 0, true}; + return TF_SetStatus(status, TF_OK, ""); + } else { + return TF_SetStatus(status, TF_INVALID_ARGUMENT, "Not a directory!"); + } + }; + GcsFileStat stat; + MaybeAppendSlash(&dir); + gcs_file->stat_cache->LookupOrCompute(dir, &stat, compute_func, status); + if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_INVALID_ARGUMENT) + return false; + if (TF_GetCode(status) == TF_INVALID_ARGUMENT) { + TF_SetStatus(status, TF_OK, ""); + return false; + } + return true; +} + +static void ClearFileCaches(GCSFile* gcs_file, const std::string& path) { + absl::ReaderMutexLock l(&gcs_file->block_cache_lock); + gcs_file->file_block_cache->RemoveFile(path); + gcs_file->stat_cache->Delete(path); +} + +void PathExists(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { std::string bucket, object; ParseGCSPath(path, true, &bucket, &object, status); if (TF_GetCode(status) != TF_OK) return; + auto gcs_file = static_cast(filesystem->plugin_filesystem); if (object.empty()) { - auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket); - TF_SetStatusFromGCSStatus(bucket_metadata.status(), status); + bool result = BucketExists(gcs_file, bucket, status); + if (result) return TF_SetStatus(status, TF_OK, ""); + } + + GcsFileStat stat; + StatForObject(gcs_file, path, bucket, object, &stat, status); + if (TF_GetCode(status) != TF_NOT_FOUND) return; + + bool result = FolderExists(gcs_file, path, status); + if (TF_GetCode(status) != TF_OK || (TF_GetCode(status) == TF_OK && result)) + return; + return TF_SetStatus( + status, TF_NOT_FOUND, + absl::StrCat("The path ", path, " does not exist.").c_str()); +} + +void CreateDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + std::string dir = path; + MaybeAppendSlash(&dir); + TF_VLog(3, + "CreateDir: creating directory with path: %s and " + "path_with_slash: %s", + path, dir.c_str()); + std::string bucket, object; + ParseGCSPath(dir, true, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + auto gcs_file = static_cast(filesystem->plugin_filesystem); + if (object.empty()) { + bool is_directory = BucketExists(gcs_file, bucket, status); + if (TF_GetCode(status) != TF_OK) return; + if (!is_directory) + TF_SetStatus(status, TF_NOT_FOUND, + absl::StrCat("The specified bucket ", dir, " was not found.") + .c_str()); return; } - MaybeAppendSlash(&object); - auto object_metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object); - TF_SetStatusFromGCSStatus(object_metadata.status(), status); - if (TF_GetCode(status) == TF_NOT_FOUND) { - auto insert_metadata = - gcs_file->gcs_client.InsertObject(bucket, object, ""); - TF_SetStatusFromGCSStatus(insert_metadata.status(), 
status); - } else if (TF_GetCode(status) == TF_OK) { - TF_SetStatus(status, TF_ALREADY_EXISTS, path); + PathExists(filesystem, dir.c_str(), status); + if (TF_GetCode(status) == TF_OK) { + // Use the original name for a correct error here. + TF_VLog(3, "CreateDir: directory already exists, not uploading %s", path); + return TF_SetStatus(status, TF_ALREADY_EXISTS, path); } + + auto metadata = gcs_file->gcs_client.InsertObject( + bucket, object, "", + // Adding this parameter means HTTP_CODE_PRECONDITION_FAILED + // will be returned if the object already exists, so avoid reuploading. + gcs::IfGenerationMatch(0), gcs::Fields("")); + TF_SetStatusFromGCSStatus(metadata.status(), status); + if (TF_GetCode(status) == TF_FAILED_PRECONDITION) + TF_SetStatus(status, TF_ALREADY_EXISTS, path); } // TODO(vnvo2409): `RecursivelyCreateDir` should use `CreateDir` instead of the @@ -700,79 +893,31 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path, auto gcs_file = static_cast(filesystem->plugin_filesystem); auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object); TF_SetStatusFromGCSStatus(gcs_status, status); + if (TF_GetCode(status) == TF_OK) ClearFileCaches(gcs_file, path); } +// Checks that the directory is empty (i.e no objects with this prefix exist). +// Deletes the GCS directory marker if it exists. void DeleteDir(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { - std::string bucket, object; - ParseGCSPath(path, false, &bucket, &object, status); - if (TF_GetCode(status) != TF_OK) return; - MaybeAppendSlash(&object); + // A directory is considered empty either if there are no matching objects + // with the corresponding name prefix or if there is exactly one matching + // object and it is the directory marker. Therefore we need to retrieve + // at most two children for the prefix to detect if a directory is empty. auto gcs_file = static_cast(filesystem->plugin_filesystem); - int object_count = 0; - for (auto&& metadata : - gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) { - if (!metadata) { - TF_SetStatusFromGCSStatus(metadata.status(), status); - return; - } - ++object_count; - // We consider a path is a non-empty directory in two cases: - // - There are more than two objects whose keys start with the name of this - // directory. - // - There is one object whose key contains the name of this directory ( but - // not equal ). - if (object_count > 1 || metadata->name() != object) { - TF_SetStatus(status, TF_FAILED_PRECONDITION, - "Cannot delete a non-empty directory."); - return; - } - } - auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object); - TF_SetStatusFromGCSStatus(gcs_status, status); -} - -// TODO(vnvo2409): `DeleteRecursively` needs `GetChildrens` but there will be -// some differents compared to the default implementation. Will be refactored. 
-static void DeleteRecursively(const TF_Filesystem* filesystem, const char* path, - uint64_t* undeleted_files, - uint64_t* undeleted_dirs, TF_Status* status) { - std::string bucket, object; - ParseGCSPath(path, false, &bucket, &object, status); + auto childrens = GetChildrenBounded(gcs_file, path, 2, true, true, status); if (TF_GetCode(status) != TF_OK) return; - - auto gcs_file = static_cast(filesystem->plugin_filesystem); - auto gcs_status = gcs::DeleteByPrefix(gcs_file->gcs_client, bucket, object); - TF_SetStatusFromGCSStatus(gcs_status, status); - if (TF_GetCode(status) != TF_OK) return; - *undeleted_dirs = 0; - *undeleted_files = 0; -} - -// TODO(vnvo2409): `RewriteObjectBlocking` will set `status` to `TF_NOT_FOUND` -// if the object does not exist. In that case, we will have to check if the -// `src` is a directory or not to set the correspondent `status` (i.e -// `TF_NOT_FOUND` if path `src` does not exist, `TF_FAILED_PRECONDITION` if -// path `src` is a directory). -void RenameFile(const TF_Filesystem* filesystem, const char* src, - const char* dst, TF_Status* status) { - std::string bucket_src, object_src; - ParseGCSPath(src, false, &bucket_src, &object_src, status); - if (TF_GetCode(status) != TF_OK) return; - - std::string bucket_dst, object_dst; - ParseGCSPath(dst, false, &bucket_dst, &object_dst, status); - if (TF_GetCode(status) != TF_OK) return; - - auto gcs_file = static_cast(filesystem->plugin_filesystem); - auto metadata = gcs_file->gcs_client.RewriteObjectBlocking( - bucket_src, object_src, bucket_dst, object_dst); - if (!metadata) { - TF_SetStatusFromGCSStatus(metadata.status(), status); + if (childrens.size() > 1 || (childrens.size() == 1 && !childrens[0].empty())) + return TF_SetStatus(status, TF_FAILED_PRECONDITION, + "Cannot delete a non-empty directory."); + if (childrens.size() == 1 && childrens[0].empty()) { + // This is the directory marker object. Delete it. + std::string dir = path; + MaybeAppendSlash(&dir); + DeleteFile(filesystem, dir.c_str(), status); return; } - auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket_src, object_src); - TF_SetStatusFromGCSStatus(gcs_status, status); + TF_SetStatus(status, TF_OK, ""); } void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst, @@ -787,35 +932,11 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst, auto gcs_file = static_cast(filesystem->plugin_filesystem); auto metadata = gcs_file->gcs_client.RewriteObjectBlocking( - bucket_src, object_src, bucket_dst, object_dst); + bucket_src, object_src, bucket_dst, object_dst, + gcs::Fields("done,rewriteToken")); TF_SetStatusFromGCSStatus(metadata.status(), status); } -// TODO(vnvo2409): This approach can cause a problem when our path is -// `path/to/dir` and there is an object with key `path/to/directory`. Will be -// fixed when refactoring. -void PathExists(const TF_Filesystem* filesystem, const char* path, - TF_Status* status) { - std::string bucket, object; - ParseGCSPath(path, true, &bucket, &object, status); - if (TF_GetCode(status) != TF_OK) return; - - auto gcs_file = static_cast(filesystem->plugin_filesystem); - for (auto&& metadata : - gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) { - if (!metadata) { - TF_SetStatusFromGCSStatus(metadata.status(), status); - return; - } - // We consider a path exists if there is at least one object whose key - // contains the path. 
- return TF_SetStatus(status, TF_OK, ""); - } - return TF_SetStatus( - status, TF_NOT_FOUND, - absl::StrCat("The path ", path, " does not exist.").c_str()); -} - bool IsDirectory(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { std::string bucket, object; @@ -824,41 +945,133 @@ bool IsDirectory(const TF_Filesystem* filesystem, const char* path, auto gcs_file = static_cast(filesystem->plugin_filesystem); if (object.empty()) { - auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket); - TF_SetStatusFromGCSStatus(bucket_metadata.status(), status); - if (TF_GetCode(status) == TF_OK) - return true; - else - return false; + bool result = BucketExists(gcs_file, bucket, status); + if (TF_GetCode(status) != TF_OK) return false; + if (!result) + TF_SetStatus( + status, TF_NOT_FOUND, + absl::StrCat("The specified bucket gs://", bucket, " was not found.") + .c_str()); + return result; } - // We check if there is an object with this key on the GCS server. - auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object); - if (metadata) { - TF_SetStatus(status, TF_OK, ""); - if (metadata->name().back() == '/') - return true; - else - return false; - } + bool is_folder = FolderExists(gcs_file, path, status); + if (TF_GetCode(status) != TF_OK) return false; + if (is_folder) return true; - // If there is no object with this key on the GCS server. We check if there is - // any object whose key contains that path. - MaybeAppendSlash(&object); - for (auto&& metadata : - gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) { - if (!metadata) { - TF_SetStatusFromGCSStatus(metadata.status(), status); - return false; - } - TF_SetStatus(status, TF_OK, ""); - return true; + bool is_object = ObjectExists(gcs_file, path, bucket, object, status); + if (TF_GetCode(status) != TF_OK) return false; + if (is_object) { + TF_SetStatus( + status, TF_FAILED_PRECONDITION, + absl::StrCat("The specified path ", path, " is not a directory.") + .c_str()); + return false; } TF_SetStatus(status, TF_NOT_FOUND, absl::StrCat("The path ", path, " does not exist.").c_str()); return false; } +static void RenameObject(const TF_Filesystem* filesystem, + const std::string& src, const std::string& dst, + TF_Status* status) { + TF_VLog(3, "RenameObject: started %s to %s", src.c_str(), dst.c_str()); + std::string bucket_src, object_src; + ParseGCSPath(src, false, &bucket_src, &object_src, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string bucket_dst, object_dst; + ParseGCSPath(dst, false, &bucket_dst, &object_dst, status); + if (TF_GetCode(status) != TF_OK) return; + + auto gcs_file = static_cast(filesystem->plugin_filesystem); + auto metadata = gcs_file->gcs_client.RewriteObjectBlocking( + bucket_src, object_src, bucket_dst, object_dst, + gcs::Fields("done,rewriteToken")); + TF_SetStatusFromGCSStatus(metadata.status(), status); + if (TF_GetCode(status) != TF_OK) return; + TF_VLog(3, "RenameObject: finished %s to %s", src.c_str(), dst.c_str()); + + ClearFileCaches(gcs_file, dst); + DeleteFile(filesystem, src.c_str(), status); +} + +void RenameFile(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status) { + if (!IsDirectory(filesystem, src, status)) { + if (TF_GetCode(status) == TF_FAILED_PRECONDITION) { + TF_SetStatus(status, TF_OK, ""); + RenameObject(filesystem, src, dst, status); + } + return; + } + + auto gcs_file = static_cast(filesystem->plugin_filesystem); + std::vector childrens = + GetChildrenBounded(gcs_file, src, UINT64_MAX, true, 
true, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string src_dir = src; + std::string dst_dir = dst; + MaybeAppendSlash(&src_dir); + MaybeAppendSlash(&dst_dir); + for (const std::string& children : childrens) { + RenameObject(filesystem, src_dir + children, dst_dir + children, status); + if (TF_GetCode(status) != TF_OK) return; + } + TF_SetStatus(status, TF_OK, ""); +} + +void DeleteRecursively(const TF_Filesystem* filesystem, const char* path, + uint64_t* undeleted_files, uint64_t* undeleted_dirs, + TF_Status* status) { + if (!undeleted_files || !undeleted_dirs) + return TF_SetStatus( + status, TF_INTERNAL, + "'undeleted_files' and 'undeleted_dirs' cannot be nullptr."); + *undeleted_files = 0; + *undeleted_dirs = 0; + if (!IsDirectory(filesystem, path, status)) { + *undeleted_dirs = 1; + return; + } + auto gcs_file = static_cast(filesystem->plugin_filesystem); + std::vector childrens = + GetChildrenBounded(gcs_file, path, UINT64_MAX, true, true, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string dir = path; + MaybeAppendSlash(&dir); + for (const std::string& children : childrens) { + const std::string& full_path = dir + children; + DeleteFile(filesystem, full_path.c_str(), status); + if (TF_GetCode(status) != TF_OK) { + if (IsDirectory(filesystem, full_path.c_str(), status)) + // The object is a directory marker. + (*undeleted_dirs)++; + else + (*undeleted_files)++; + } + } +} + +int GetChildren(const TF_Filesystem* filesystem, const char* path, + char*** entries, TF_Status* status) { + auto gcs_file = static_cast(filesystem->plugin_filesystem); + std::vector childrens = + GetChildrenBounded(gcs_file, path, UINT64_MAX, false, false, status); + if (TF_GetCode(status) != TF_OK) return -1; + + int num_entries = childrens.size(); + *entries = static_cast( + plugin_memory_allocate(num_entries * sizeof((*entries)[0]))); + for (int i = 0; i < num_entries; i++) + (*entries)[i] = strdup(childrens[i].c_str()); + TF_SetStatus(status, TF_OK, ""); + return num_entries; +} + void Stat(const TF_Filesystem* filesystem, const char* path, TF_FileStatistics* stats, TF_Status* status) { std::string bucket, object; @@ -867,7 +1080,8 @@ void Stat(const TF_Filesystem* filesystem, const char* path, auto gcs_file = static_cast(filesystem->plugin_filesystem); if (object.empty()) { - auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket); + auto bucket_metadata = + gcs_file->gcs_client.GetBucketMetadata(bucket, gcs::Fields("")); TF_SetStatusFromGCSStatus(bucket_metadata.status(), status); if (TF_GetCode(status) == TF_OK) { stats->is_directory = true; @@ -882,8 +1096,9 @@ void Stat(const TF_Filesystem* filesystem, const char* path, stats->mtime_nsec = 0; return TF_SetStatus(status, TF_OK, ""); } - if (TF_GetCode(status) == TF_OK) { - auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object); + if (TF_GetCode(status) == TF_FAILED_PRECONDITION) { + auto metadata = gcs_file->gcs_client.GetObjectMetadata( + bucket, object, gcs::Fields("size,timeStorageClassUpdated")); if (metadata) { stats->is_directory = false; stats->length = metadata.value().size(); @@ -896,6 +1111,29 @@ void Stat(const TF_Filesystem* filesystem, const char* path, } } +int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + // Only validate the name. 
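  // The size itself comes from the Stat() call below; the returned length is
  // only meaningful when the resulting status is TF_OK, so callers are
  // expected to check the status before using the value.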
+ std::string bucket, object; + ParseGCSPath(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return -1; + + TF_FileStatistics stat; + Stat(filesystem, path, &stat, status); + return stat.length; +} + +static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) { + return strdup(uri); +} + +static void FlushCaches(const TF_Filesystem* filesystem) { + auto gcs_file = static_cast(filesystem->plugin_filesystem); + absl::ReaderMutexLock l(&gcs_file->block_cache_lock); + gcs_file->file_block_cache->Flush(); + gcs_file->stat_cache->Clear(); +} + } // namespace tf_gcs_filesystem static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, @@ -912,6 +1150,13 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE)); ops->writable_file_ops->cleanup = tf_writable_file::Cleanup; + ops->read_only_memory_region_ops = static_cast( + plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE)); + ops->read_only_memory_region_ops->cleanup = + tf_read_only_memory_region::Cleanup; + ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data; + ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length; + ops->filesystem_ops = static_cast( plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE)); ops->filesystem_ops->init = tf_gcs_filesystem::Init; @@ -921,6 +1166,20 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile; ops->filesystem_ops->new_appendable_file = tf_gcs_filesystem::NewAppendableFile; + ops->filesystem_ops->new_read_only_memory_region_from_file = + tf_gcs_filesystem::NewReadOnlyMemoryRegionFromFile; + ops->filesystem_ops->create_dir = tf_gcs_filesystem::CreateDir; + ops->filesystem_ops->delete_file = tf_gcs_filesystem::DeleteFile; + ops->filesystem_ops->delete_dir = tf_gcs_filesystem::DeleteDir; + ops->filesystem_ops->delete_recursively = + tf_gcs_filesystem::DeleteRecursively; + ops->filesystem_ops->copy_file = tf_gcs_filesystem::CopyFile; + ops->filesystem_ops->path_exists = tf_gcs_filesystem::PathExists; + ops->filesystem_ops->is_directory = tf_gcs_filesystem::IsDirectory; + ops->filesystem_ops->stat = tf_gcs_filesystem::Stat; + ops->filesystem_ops->get_children = tf_gcs_filesystem::GetChildren; + ops->filesystem_ops->translate_name = tf_gcs_filesystem::TranslateName; + ops->filesystem_ops->flush_caches = tf_gcs_filesystem::FlushCaches; } void TF_InitPlugin(TF_FilesystemPluginInfo* info) { diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h index 973ce9e9dc2..5612d004d82 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h @@ -87,6 +87,24 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, const char* path, TF_ReadOnlyMemoryRegion* region, TF_Status* status); +int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void PathExists(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void CreateDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +int GetChildren(const TF_Filesystem* filesystem, const char* path, + char*** entries, TF_Status* status); +void DeleteFile(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void Stat(const 
TF_Filesystem* filesystem, const char* path, + TF_FileStatistics* stats, TF_Status* status); +void DeleteDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst, + TF_Status* status); +void RenameFile(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status); } // namespace tf_gcs_filesystem #endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc index 82c4e4b8705..e15921335ab 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x) +#define EXPECT_TF_OK(x) EXPECT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x) static const char* content = "abcdefghijklmnopqrstuvwxyz1234567890"; // We will work with content_view instead of content. @@ -94,6 +95,70 @@ class GCSFilesystemTest : public ::testing::Test { return translated_name; } + std::unique_ptr + GetWriter() { + std::unique_ptr writer( + new TF_WritableFile, [](TF_WritableFile* file) { + if (file != nullptr) { + if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file); + delete file; + } + }); + writer->plugin_file = nullptr; + return writer; + } + + std::unique_ptr + GetReader() { + std::unique_ptr + reader(new TF_RandomAccessFile, [](TF_RandomAccessFile* file) { + if (file != nullptr) { + if (file->plugin_file != nullptr) + tf_random_access_file::Cleanup(file); + delete file; + } + }); + reader->plugin_file = nullptr; + return reader; + } + + void WriteString(const std::string& path, const std::string& content) { + auto writer = GetWriter(); + tf_gcs_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(), + status_); + if (TF_GetCode(status_) != TF_OK) return; + tf_writable_file::Append(writer.get(), content.c_str(), content.length(), + status_); + if (TF_GetCode(status_) != TF_OK) return; + tf_writable_file::Close(writer.get(), status_); + if (TF_GetCode(status_) != TF_OK) return; + } + + std::string ReadAll(const std::string& path) { + auto reader = GetReader(); + tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), + reader.get(), status_); + if (TF_GetCode(status_) != TF_OK) return ""; + + auto file_size = + tf_gcs_filesystem::GetFileSize(filesystem_, path.c_str(), status_); + if (TF_GetCode(status_) != TF_OK) return ""; + + std::string content; + content.resize(file_size); + auto read = tf_random_access_file::Read(reader.get(), 0, file_size, + &content[0], status_); + if (TF_GetCode(status_) != TF_OK) return ""; + if (read >= 0) content.resize(read); + if (file_size != content.size()) + TF_SetStatus( + status_, TF_DATA_LOSS, + std::string("expected " + std::to_string(file_size) + " got " + + std::to_string(content.size()) + " bytes") + .c_str()); + return content; + } + protected: TF_Filesystem* filesystem_; TF_Status* status_; @@ -326,6 +391,145 @@ TEST_F(GCSFilesystemTest, ReadOnlyMemoryRegion) { delete region; } +TEST_F(GCSFilesystemTest, PathExists) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string path = GetURIForPath("PathExists"); + tf_gcs_filesystem::PathExists(filesystem_, 
path.c_str(), status_); + EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status_)) << TF_Message(status_); + TF_SetStatus(status_, TF_OK, ""); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + tf_gcs_filesystem::PathExists(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); +} + +TEST_F(GCSFilesystemTest, GetChildren) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string base = GetURIForPath("GetChildren"); + tf_gcs_filesystem::CreateDir(filesystem_, base.c_str(), status_); + EXPECT_TF_OK(status_); + + const std::string file = io::JoinPath(base, "TestFile.csv"); + WriteString(file, "test"); + EXPECT_TF_OK(status_); + + const std::string subdir = io::JoinPath(base, "SubDir"); + tf_gcs_filesystem::CreateDir(filesystem_, subdir.c_str(), status_); + EXPECT_TF_OK(status_); + const std::string subfile = io::JoinPath(subdir, "TestSubFile.csv"); + WriteString(subfile, "test"); + EXPECT_TF_OK(status_); + + char** entries; + auto num_entries = tf_gcs_filesystem::GetChildren(filesystem_, base.c_str(), + &entries, status_); + EXPECT_TF_OK(status_); + + std::vector childrens; + for (int i = 0; i < num_entries; ++i) { + childrens.push_back(entries[i]); + } + std::sort(childrens.begin(), childrens.end()); + EXPECT_EQ(std::vector({"SubDir/", "TestFile.csv"}), childrens); +} + +TEST_F(GCSFilesystemTest, DeleteFile) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string path = GetURIForPath("DeleteFile"); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + tf_gcs_filesystem::DeleteFile(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); + tf_gcs_filesystem::PathExists(filesystem_, path.c_str(), status_); + EXPECT_EQ(TF_GetCode(status_), TF_NOT_FOUND); +} + +TEST_F(GCSFilesystemTest, CreateDir) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string dir = GetURIForPath("CreateDir"); + tf_gcs_filesystem::CreateDir(filesystem_, dir.c_str(), status_); + EXPECT_TF_OK(status_); + + TF_FileStatistics stat; + tf_gcs_filesystem::Stat(filesystem_, dir.c_str(), &stat, status_); + EXPECT_TF_OK(status_); + EXPECT_TRUE(stat.is_directory); +} + +TEST_F(GCSFilesystemTest, DeleteDir) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string dir = GetURIForPath("DeleteDir"); + const std::string file = io::JoinPath(dir, "DeleteDirFile.csv"); + WriteString(file, "test"); + ASSERT_TF_OK(status_); + tf_gcs_filesystem::DeleteDir(filesystem_, dir.c_str(), status_); + EXPECT_EQ(TF_GetCode(status_), TF_FAILED_PRECONDITION); + + TF_SetStatus(status_, TF_OK, ""); + tf_gcs_filesystem::DeleteFile(filesystem_, file.c_str(), status_); + EXPECT_TF_OK(status_); + tf_gcs_filesystem::DeleteDir(filesystem_, dir.c_str(), status_); + EXPECT_TF_OK(status_); + TF_FileStatistics stat; + tf_gcs_filesystem::Stat(filesystem_, dir.c_str(), &stat, status_); + EXPECT_EQ(TF_GetCode(status_), TF_NOT_FOUND) << TF_Message(status_); +} + +TEST_F(GCSFilesystemTest, StatFile) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string path = GetURIForPath("StatFile"); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + + TF_FileStatistics stat; + tf_gcs_filesystem::Stat(filesystem_, path.c_str(), &stat, status_); + EXPECT_TF_OK(status_); + EXPECT_EQ(4, stat.length); + EXPECT_FALSE(stat.is_directory); +} + +TEST_F(GCSFilesystemTest, RenameFile) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string src = 
GetURIForPath("RenameFileSrc"); + const std::string dst = GetURIForPath("RenameFileDst"); + WriteString(src, "test"); + ASSERT_TF_OK(status_); + + tf_gcs_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_); + EXPECT_TF_OK(status_); + auto result = ReadAll(dst); + EXPECT_TF_OK(status_); + EXPECT_EQ("test", result); +} + +TEST_F(GCSFilesystemTest, RenameFileOverwrite) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_); + const std::string src = GetURIForPath("RenameFileOverwriteSrc"); + const std::string dst = GetURIForPath("RenameFileOverwriteDst"); + + WriteString(src, "test_old"); + ASSERT_TF_OK(status_); + WriteString(dst, "test_new"); + ASSERT_TF_OK(status_); + + tf_gcs_filesystem::PathExists(filesystem_, dst.c_str(), status_); + EXPECT_TF_OK(status_); + tf_gcs_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_); + EXPECT_TF_OK(status_); + + auto result = ReadAll(dst); + EXPECT_TF_OK(status_); + EXPECT_EQ("test_old", result); +} + // These tests below are ported from // `//tensorflow/core/platform/cloud:gcs_file_system_test` TEST_F(GCSFilesystemTest, NewRandomAccessFile_NoBlockCache) { diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h index 2abfb6f924b..72659a97d42 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" #include "tensorflow/c/env.h" +#include "tensorflow/c/logging.h" #include "tensorflow/c/tf_status.h" namespace tf_gcs_filesystem { @@ -65,8 +66,8 @@ class RamFileBlockCache { pruning_thread_.reset( TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this)); } - std::cout << "GCS file block cache is " - << (IsCacheEnabled() ? "enabled" : "disabled") << ".\n"; + TF_VLog(1, "GCS file block cache is %s.\n", + (IsCacheEnabled() ? "enabled" : "disabled")); } ~RamFileBlockCache() { diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD b/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD index 51ffd709f3d..bb97587d6d1 100644 --- a/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD @@ -1,5 +1,5 @@ # Experimental hadoop filesystem plugin. -load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object") +load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test") package( licenses = ["notice"], # Apache 2.0 @@ -33,3 +33,38 @@ cc_library( "@com_google_absl//absl/synchronization", ], ) + +# This test is set to manual because it requires downloading the Hadoop +# distribution to run. To run this test: +# 1. Ensure $JAVA_HOME is set to the location of a JDK 8 installation. +# 2. Download the binary Hadoop distribution from: +# http://hadoop.apache.org/releases.html +# 3. Extract the Hadoop distribution and run: +# source libexec/hadoop-config.sh +# 4. Optionally set up HDFS cluster configurations (optionally Kerberos) within +# $HADOOP_HDFS_HOME/etc/hadoop if you want to test against real +# distributed HDFS cluster +# 5. 
bazel test \ +# --test_env=LD_LIBRARY_PATH=$JAVA_HOME/jre/lib/amd64/server \ +# --test_env=HADOOP_HDFS_HOME=$HADOOP_HDFS_HOME \ +# --test_env=CLASSPATH=$($HADOOP_HDFS_HOME/bin/hadoop classpath --glob) \ +# :hadoop_file_system_test +# To test against the real distributed cluster, add the following option for +# bazel test: +# --test_env=HADOOP_TEST_TMPDIR=hdfs://cluster/test/tmp/dir +tf_cc_test( + name = "hadoop_filesystem_test", + srcs = [ + "hadoop_filesystem_test.cc", + ], + tags = [ + "manual", + "notap", + ], + deps = [ + ":hadoop_filesystem_impl", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:stacktrace_handler", + "//tensorflow/core/platform:test", + ], +) diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc index e94be3e83a2..b904ba292ab 100644 --- a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc @@ -37,11 +37,17 @@ static void plugin_memory_free(void* ptr) { free(ptr); } void ParseHadoopPath(const std::string& fname, std::string* scheme, std::string* namenode, std::string* path) { size_t scheme_end = fname.find("://") + 2; - *scheme = fname.substr(0, scheme_end + 1); + // We don't want `://` in scheme. + *scheme = fname.substr(0, scheme_end - 2); size_t nn_end = fname.find("/", scheme_end + 1); - if (nn_end == std::string::npos) return; + if (nn_end == std::string::npos) { + *namenode = fname.substr(scheme_end + 1); + *path = ""; + return; + } *namenode = fname.substr(scheme_end + 1, nn_end - scheme_end - 1); - *path = fname.substr(nn_end + 1); + // We keep `/` in path. + *path = fname.substr(nn_end); } void SplitArchiveNameAndPath(std::string* path, std::string* nn, @@ -54,7 +60,7 @@ void SplitArchiveNameAndPath(std::string* path, std::string* nn, } // Case of hadoop archive. Namenode is the path to the archive. std::ostringstream namenodestream; - namenodestream << "har://" << nn + namenodestream << "har://" << *nn << path->substr(0, index_end_archive_name + 4); *nn = namenodestream.str(); path->erase(0, index_end_archive_name + 4); @@ -247,8 +253,8 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, char* dst = buffer; bool eof_retried = false; - int64_t r = 0; - while (TF_GetCode(status) == TF_OK && !eof_retried) { + int64_t read = 0; + while (TF_GetCode(status) == TF_OK && n > 0) { // We lock inside the loop rather than outside so we don't block other // concurrent readers. absl::MutexLock l(&hdfs_file->mu); @@ -257,12 +263,13 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, // of int32. -2 offset can avoid JVM OutOfMemoryError. size_t read_n = (std::min)(n, static_cast(std::numeric_limits::max() - 2)); - r = libhdfs->hdfsPread(fs, handle, static_cast(offset), dst, - static_cast(read_n)); + int64_t r = libhdfs->hdfsPread(fs, handle, static_cast(offset), + dst, static_cast(read_n)); if (r > 0) { dst += r; n -= r; offset += r; + read += r; } else if (!eof_retried && r == 0) { // Always reopen the file upon reaching EOF to see if there's more data. 
// If writers are streaming contents while others are concurrently @@ -274,11 +281,13 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, TF_SetStatusFromIOError(status, errno, path); return -1; } - handle = libhdfs->hdfsOpenFile(fs, hdfs_path, O_RDONLY, 0, 0, 0); - if (handle == nullptr) { + hdfs_file->handle = + libhdfs->hdfsOpenFile(fs, hdfs_path, O_RDONLY, 0, 0, 0); + if (hdfs_file->handle == nullptr) { TF_SetStatusFromIOError(status, errno, path); return -1; } + handle = hdfs_file->handle; eof_retried = true; } else if (eof_retried && r == 0) { TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested"); @@ -288,7 +297,7 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, TF_SetStatusFromIOError(status, errno, path); } } - return r; + return read; } } // namespace tf_random_access_file @@ -308,7 +317,7 @@ typedef struct HDFSFile { handle(handle) {} } HDFSFile; -static void Cleanup(TF_WritableFile* file) { +void Cleanup(TF_WritableFile* file) { auto hdfs_file = static_cast(file->plugin_file); hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle); hdfs_file->fs = nullptr; @@ -433,6 +442,23 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path, std::string scheme, namenode, hdfs_path; ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + auto handle = libhdfs->hdfsOpenFile(fs, hdfs_path.c_str(), O_WRONLY, 0, 0, 0); + if (handle == nullptr) return TF_SetStatusFromIOError(status, errno, path); + + file->plugin_file = + new tf_writable_file::HDFSFile(hdfs_path, fs, libhdfs, handle); + TF_SetStatus(status, TF_OK, ""); +} + +void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + auto handle = libhdfs->hdfsOpenFile(fs, hdfs_path.c_str(), O_WRONLY | O_APPEND, 0, 0, 0); if (handle == nullptr) return TF_SetStatusFromIOError(status, errno, path); @@ -442,6 +468,202 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path, TF_SetStatus(status, TF_OK, ""); } +void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, + const char* path, + TF_ReadOnlyMemoryRegion* region, + TF_Status* status) { + // hadoopReadZero() technically supports this call with the following + // caveats: + // - It only works up to 2 GB. We'd have to Stat() the file to ensure that + // it fits. + // - If not on the local filesystem, the entire file will be read, making + // it inefficient for callers that assume typical mmap() behavior. 
+ TF_SetStatus(status, TF_UNIMPLEMENTED, + "HDFS does not support ReadOnlyMemoryRegion"); +} + +void PathExists(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + if (libhdfs->hdfsExists(fs, hdfs_path.c_str()) == 0) + TF_SetStatus(status, TF_OK, ""); + else + TF_SetStatus(status, TF_NOT_FOUND, + (std::string(path) + " not found").c_str()); +} + +void Stat(const TF_Filesystem* filesystem, const char* path, + TF_FileStatistics* stats, TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + auto info = libhdfs->hdfsGetPathInfo(fs, hdfs_path.c_str()); + if (info == nullptr) return TF_SetStatusFromIOError(status, errno, path); + + stats->length = static_cast(info->mSize); + stats->mtime_nsec = static_cast(info->mLastMod) * 1e9; + stats->is_directory = info->mKind == kObjectKindDirectory; + libhdfs->hdfsFreeFileInfo(info, 1); + TF_SetStatus(status, TF_OK, ""); +} + +int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return -1; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + auto info = libhdfs->hdfsGetPathInfo(fs, hdfs_path.c_str()); + if (info == nullptr) { + TF_SetStatusFromIOError(status, errno, path); + return -1; + } + + TF_SetStatus(status, TF_OK, ""); + auto size = static_cast(info->mSize); + libhdfs->hdfsFreeFileInfo(info, 1); + return size; +} + +void DeleteFile(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + if (libhdfs->hdfsDelete(fs, hdfs_path.c_str(), /*recursive=*/0) != 0) + TF_SetStatusFromIOError(status, errno, path); + else + TF_SetStatus(status, TF_OK, ""); +} + +void CreateDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + if (libhdfs->hdfsCreateDirectory(fs, hdfs_path.c_str()) != 0) + TF_SetStatusFromIOError(status, errno, path); + else + TF_SetStatus(status, TF_OK, ""); +} + +void DeleteDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + // Count the number of entries in the directory, and only delete if it's + // non-empty. 
This is consistent with the interface, but note that there's + // a race condition where a file may be added after this check, in which + // case the directory will still be deleted. + int entries = 0; + auto info = libhdfs->hdfsListDirectory(fs, hdfs_path.c_str(), &entries); + if (info != nullptr) libhdfs->hdfsFreeFileInfo(info, entries); + + // Due to HDFS bug HDFS-8407, we can't distinguish between an error and empty + // folder, especially for Kerberos enable setup, EAGAIN is quite common when + // the call is actually successful. Check again by Stat. + if (info == nullptr && errno != 0) { + TF_FileStatistics stat; + Stat(filesystem, path, &stat, status); + if (TF_GetCode(status) != TF_OK) return; + } + + if (entries > 0) + return TF_SetStatus(status, TF_FAILED_PRECONDITION, + "Cannot delete a non-empty directory."); + + if (libhdfs->hdfsDelete(fs, hdfs_path.c_str(), /*recursive=*/1) != 0) + TF_SetStatusFromIOError(status, errno, path); + else + TF_SetStatus(status, TF_OK, ""); +} + +void RenameFile(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, src, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path_src, hdfs_path_dst; + ParseHadoopPath(src, &scheme, &namenode, &hdfs_path_src); + ParseHadoopPath(dst, &scheme, &namenode, &hdfs_path_dst); + + if (libhdfs->hdfsExists(fs, hdfs_path_dst.c_str()) == 0 && + libhdfs->hdfsDelete(fs, hdfs_path_dst.c_str(), /*recursive=*/0) != 0) + return TF_SetStatusFromIOError(status, errno, dst); + + if (libhdfs->hdfsRename(fs, hdfs_path_src.c_str(), hdfs_path_dst.c_str()) != + 0) + TF_SetStatusFromIOError(status, errno, src); + else + TF_SetStatus(status, TF_OK, ""); +} + +int GetChildren(const TF_Filesystem* filesystem, const char* path, + char*** entries, TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return -1; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + // hdfsListDirectory returns nullptr if the directory is empty. Do a separate + // check to verify the directory exists first. + TF_FileStatistics stat; + Stat(filesystem, path, &stat, status); + if (TF_GetCode(status) != TF_OK) return -1; + + int num_entries = 0; + auto info = libhdfs->hdfsListDirectory(fs, hdfs_path.c_str(), &num_entries); + if (info == nullptr) { + if (stat.is_directory) { + // Assume it's an empty directory. 
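      // (A nullptr listing is also how libhdfs reports an empty directory, so
      // it cannot be told apart from a failure by the return value alone; the
      // Stat() call above has already confirmed the path is a directory, which
      // is why an empty result is returned here instead of an error.)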
+ TF_SetStatus(status, TF_OK, ""); + return 0; + } + TF_SetStatusFromIOError(status, errno, path); + return -1; + } + *entries = static_cast( + plugin_memory_allocate(num_entries * sizeof((*entries)[0]))); + auto BaseName = [](const std::string& name) { + return name.substr(name.find_last_of('/') + 1); + }; + for (int i = 0; i < num_entries; i++) { + (*entries)[i] = strdup(BaseName(info[i].mName).c_str()); + } + libhdfs->hdfsFreeFileInfo(info, num_entries); + TF_SetStatus(status, TF_OK, ""); + return num_entries; +} + // TODO(vnvo2409): Implement later } // namespace tf_hadoop_filesystem diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h index 850cefe0231..8de66c05bac 100644 --- a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h @@ -15,7 +15,62 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_ #define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_ +#include + #include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/tf_status.h" +void ParseHadoopPath(const std::string& fname, std::string* scheme, + std::string* namenode, std::string* path); +void SplitArchiveNameAndPath(std::string* path, std::string* nn, + TF_Status* status); +class LibHDFS; + +namespace tf_random_access_file { +void Cleanup(TF_RandomAccessFile* file); +int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, + char* buffer, TF_Status* status); +} // namespace tf_random_access_file + +namespace tf_writable_file { +void Cleanup(TF_WritableFile* file); +void Append(const TF_WritableFile* file, const char* buffer, size_t n, + TF_Status* status); +int64_t Tell(const TF_WritableFile* file, TF_Status* status); +void Sync(const TF_WritableFile* file, TF_Status* status); +void Flush(const TF_WritableFile* file, TF_Status* status); +void Close(const TF_WritableFile* file, TF_Status* status); +} // namespace tf_writable_file + +namespace tf_hadoop_filesystem { +void Init(TF_Filesystem* filesystem, TF_Status* status); +void Cleanup(TF_Filesystem* filesystem); +void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, + TF_RandomAccessFile* file, TF_Status* status); +void NewWritableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status); +void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status); +void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, + const char* path, + TF_ReadOnlyMemoryRegion* region, + TF_Status* status); +void PathExists(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void Stat(const TF_Filesystem* filesystem, const char* path, + TF_FileStatistics* stats, TF_Status* status); +int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void DeleteFile(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void CreateDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void DeleteDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void RenameFile(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status); +int GetChildren(const TF_Filesystem* filesystem, const char* path, + 
char*** entries, TF_Status* status); +} // namespace tf_hadoop_filesystem + #endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem_test.cc b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem_test.cc new file mode 100644 index 00000000000..77079fb5325 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem_test.cc @@ -0,0 +1,418 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h" + +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/stacktrace_handler.h" +#include "tensorflow/core/platform/test.h" +#include "third_party/hadoop/hdfs.h" + +#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x) +#define EXPECT_TF_OK(x) EXPECT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x) + +namespace tensorflow { +namespace { + +class HadoopFileSystemTest : public ::testing::Test { + public: + void SetUp() override { + status_ = TF_NewStatus(); + filesystem_ = new TF_Filesystem; + tf_hadoop_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_) << "Could not initialize filesystem. 
" + << TF_Message(status_); + } + void TearDown() override { + TF_DeleteStatus(status_); + tf_hadoop_filesystem::Cleanup(filesystem_); + delete filesystem_; + } + + std::string TmpDir(const std::string& path) { + char* test_dir = getenv("HADOOP_TEST_TMPDIR"); + if (test_dir != nullptr) { + return io::JoinPath(std::string(test_dir), path); + } else { + return "file://" + io::JoinPath(testing::TmpDir(), path); + } + } + + std::unique_ptr + GetWriter() { + std::unique_ptr writer( + new TF_WritableFile, [](TF_WritableFile* file) { + if (file != nullptr) { + if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file); + delete file; + } + }); + writer->plugin_file = nullptr; + return writer; + } + + std::unique_ptr + GetReader() { + std::unique_ptr + reader(new TF_RandomAccessFile, [](TF_RandomAccessFile* file) { + if (file != nullptr) { + if (file->plugin_file != nullptr) + tf_random_access_file::Cleanup(file); + delete file; + } + }); + reader->plugin_file = nullptr; + return reader; + } + + void WriteString(const std::string& path, const std::string& content) { + auto writer = GetWriter(); + tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(), + writer.get(), status_); + if (TF_GetCode(status_) != TF_OK) return; + tf_writable_file::Append(writer.get(), content.c_str(), content.length(), + status_); + if (TF_GetCode(status_) != TF_OK) return; + tf_writable_file::Close(writer.get(), status_); + if (TF_GetCode(status_) != TF_OK) return; + } + + std::string ReadAll(const std::string& path) { + auto reader = GetReader(); + tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), + reader.get(), status_); + if (TF_GetCode(status_) != TF_OK) return ""; + + auto file_size = + tf_hadoop_filesystem::GetFileSize(filesystem_, path.c_str(), status_); + if (TF_GetCode(status_) != TF_OK) return ""; + + std::string content; + content.resize(file_size); + auto read = tf_random_access_file::Read(reader.get(), 0, file_size, + &content[0], status_); + if (TF_GetCode(status_) != TF_OK) return ""; + if (read >= 0) content.resize(read); + if (file_size != content.size()) + TF_SetStatus( + status_, TF_DATA_LOSS, + std::string("expected " + std::to_string(file_size) + " got " + + std::to_string(content.size()) + " bytes") + .c_str()); + return content; + } + + protected: + TF_Filesystem* filesystem_; + TF_Status* status_; +}; + +TEST_F(HadoopFileSystemTest, RandomAccessFile) { + const std::string path = TmpDir("RandomAccessFile"); + const std::string content = "abcdefghijklmn"; + + WriteString(path, content); + ASSERT_TF_OK(status_); + + auto reader = GetReader(); + tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), + reader.get(), status_); + EXPECT_TF_OK(status_); + + std::string result; + result.resize(content.size()); + auto read = tf_random_access_file::Read(reader.get(), 0, content.size(), + &result[0], status_); + result.resize(read); + EXPECT_TF_OK(status_); + EXPECT_EQ(content.size(), result.size()); + EXPECT_EQ(content, result); + + result.clear(); + result.resize(4); + read = tf_random_access_file::Read(reader.get(), 2, 4, &result[0], status_); + result.resize(read); + EXPECT_TF_OK(status_); + EXPECT_EQ(4, result.size()); + EXPECT_EQ(content.substr(2, 4), result); +} + +TEST_F(HadoopFileSystemTest, WritableFile) { + auto writer = GetWriter(); + const std::string path = TmpDir("WritableFile"); + tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(), + status_); + EXPECT_TF_OK(status_); + tf_writable_file::Append(writer.get(), 
"content1,", strlen("content1,"), + status_); + EXPECT_TF_OK(status_); + auto pos = tf_writable_file::Tell(writer.get(), status_); + EXPECT_TF_OK(status_); + EXPECT_EQ(pos, 9); + + tf_writable_file::Append(writer.get(), "content2", strlen("content2"), + status_); + EXPECT_TF_OK(status_); + tf_writable_file::Flush(writer.get(), status_); + EXPECT_TF_OK(status_); + tf_writable_file::Sync(writer.get(), status_); + EXPECT_TF_OK(status_); + tf_writable_file::Close(writer.get(), status_); + EXPECT_TF_OK(status_); + + auto content = ReadAll(path); + EXPECT_TF_OK(status_); + EXPECT_EQ("content1,content2", content); +} + +TEST_F(HadoopFileSystemTest, PathExists) { + const std::string path = TmpDir("PathExists"); + tf_hadoop_filesystem::PathExists(filesystem_, path.c_str(), status_); + EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status_)) << TF_Message(status_); + TF_SetStatus(status_, TF_OK, ""); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + tf_hadoop_filesystem::PathExists(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); +} + +TEST_F(HadoopFileSystemTest, GetChildren) { + const std::string base = TmpDir("GetChildren"); + tf_hadoop_filesystem::CreateDir(filesystem_, base.c_str(), status_); + EXPECT_TF_OK(status_); + + const std::string file = io::JoinPath(base, "TestFile.csv"); + WriteString(file, "test"); + EXPECT_TF_OK(status_); + + const std::string subdir = io::JoinPath(base, "SubDir"); + tf_hadoop_filesystem::CreateDir(filesystem_, subdir.c_str(), status_); + EXPECT_TF_OK(status_); + const std::string subfile = io::JoinPath(subdir, "TestSubFile.csv"); + WriteString(subfile, "test"); + EXPECT_TF_OK(status_); + + char** entries; + auto num_entries = tf_hadoop_filesystem::GetChildren( + filesystem_, base.c_str(), &entries, status_); + EXPECT_TF_OK(status_); + + std::vector childrens; + for (int i = 0; i < num_entries; ++i) { + childrens.push_back(entries[i]); + } + std::sort(childrens.begin(), childrens.end()); + EXPECT_EQ(std::vector({"SubDir", "TestFile.csv"}), childrens); +} + +TEST_F(HadoopFileSystemTest, DeleteFile) { + const std::string path = TmpDir("DeleteFile"); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + tf_hadoop_filesystem::DeleteFile(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); +} + +TEST_F(HadoopFileSystemTest, GetFileSize) { + const std::string path = TmpDir("GetFileSize"); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + auto file_size = + tf_hadoop_filesystem::GetFileSize(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); + EXPECT_EQ(4, file_size); +} + +TEST_F(HadoopFileSystemTest, CreateDirStat) { + const std::string path = TmpDir("CreateDirStat"); + tf_hadoop_filesystem::CreateDir(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); + TF_FileStatistics stat; + tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_); + EXPECT_TF_OK(status_); + EXPECT_TRUE(stat.is_directory); +} + +TEST_F(HadoopFileSystemTest, DeleteDir) { + const std::string path = TmpDir("DeleteDir"); + tf_hadoop_filesystem::DeleteDir(filesystem_, path.c_str(), status_); + EXPECT_NE(TF_GetCode(status_), TF_OK); + tf_hadoop_filesystem::CreateDir(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); + tf_hadoop_filesystem::DeleteDir(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); + TF_FileStatistics stat; + tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_); + EXPECT_NE(TF_GetCode(status_), TF_OK); +} + +TEST_F(HadoopFileSystemTest, RenameFile) { + const std::string src = 
TmpDir("RenameFileSrc"); + const std::string dst = TmpDir("RenameFileDst"); + WriteString(src, "test"); + ASSERT_TF_OK(status_); + + tf_hadoop_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), + status_); + EXPECT_TF_OK(status_); + auto result = ReadAll(dst); + EXPECT_TF_OK(status_); + EXPECT_EQ("test", result); +} + +TEST_F(HadoopFileSystemTest, RenameFileOverwrite) { + const std::string src = TmpDir("RenameFileOverwriteSrc"); + const std::string dst = TmpDir("RenameFileOverwriteDst"); + + WriteString(src, "test_old"); + ASSERT_TF_OK(status_); + WriteString(dst, "test_new"); + ASSERT_TF_OK(status_); + + tf_hadoop_filesystem::PathExists(filesystem_, dst.c_str(), status_); + EXPECT_TF_OK(status_); + tf_hadoop_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), + status_); + EXPECT_TF_OK(status_); + + auto result = ReadAll(dst); + EXPECT_TF_OK(status_); + EXPECT_EQ("test_old", result); +} + +TEST_F(HadoopFileSystemTest, StatFile) { + const std::string path = TmpDir("StatFile"); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + TF_FileStatistics stat; + tf_hadoop_filesystem::Stat(filesystem_, path.c_str(), &stat, status_); + EXPECT_TF_OK(status_); + EXPECT_EQ(4, stat.length); + EXPECT_FALSE(stat.is_directory); +} + +TEST_F(HadoopFileSystemTest, WriteWhileReading) { + const std::string path = TmpDir("WriteWhileReading"); + // Skip the test if we're not testing on HDFS. Hadoop's local filesystem + // implementation makes no guarantees that writable files are readable while + // being written. + if (path.find_first_of("hdfs://") != 0) GTEST_SKIP(); + + auto writer = GetWriter(); + tf_hadoop_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(), + status_); + EXPECT_TF_OK(status_); + + const std::string content1 = "content1"; + tf_writable_file::Append(writer.get(), content1.c_str(), content1.size(), + status_); + EXPECT_TF_OK(status_); + tf_writable_file::Flush(writer.get(), status_); + EXPECT_TF_OK(status_); + + auto reader = GetReader(); + tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), + reader.get(), status_); + EXPECT_TF_OK(status_); + + std::string result; + result.resize(content1.size()); + auto read = tf_random_access_file::Read(reader.get(), 0, content1.size(), + &result[0], status_); + result.resize(read); + EXPECT_TF_OK(status_); + EXPECT_EQ(content1, result); + + const std::string content2 = "content2"; + tf_writable_file::Append(writer.get(), content2.c_str(), content2.size(), + status_); + EXPECT_TF_OK(status_); + tf_writable_file::Flush(writer.get(), status_); + EXPECT_TF_OK(status_); + + result.resize(content2.size()); + read = tf_random_access_file::Read(reader.get(), content1.size(), + content2.size(), &result[0], status_); + result.resize(read); + EXPECT_TF_OK(status_); + EXPECT_EQ(content2, result); + + tf_writable_file::Close(writer.get(), status_); + EXPECT_TF_OK(status_); +} + +TEST_F(HadoopFileSystemTest, HarSplit) { + const std::string har_path = + "har://hdfs-root/user/j.doe/my_archive.har/dir0/dir1/file.txt"; + std::string scheme, namenode, path; + ParseHadoopPath(har_path, &scheme, &namenode, &path); + EXPECT_EQ("har", scheme); + EXPECT_EQ("hdfs-root", namenode); + EXPECT_EQ("/user/j.doe/my_archive.har/dir0/dir1/file.txt", path); + SplitArchiveNameAndPath(&path, &namenode, status_); + EXPECT_TF_OK(status_); + EXPECT_EQ("har://hdfs-root/user/j.doe/my_archive.har", namenode); + EXPECT_EQ("/dir0/dir1/file.txt", path); +} + +TEST_F(HadoopFileSystemTest, NoHarExtension) { + const std::string har_path = + 
"har://hdfs-root/user/j.doe/my_archive/dir0/dir1/file.txt"; + std::string scheme, namenode, path; + ParseHadoopPath(har_path, &scheme, &namenode, &path); + EXPECT_EQ("har", scheme); + EXPECT_EQ("hdfs-root", namenode); + EXPECT_EQ("/user/j.doe/my_archive/dir0/dir1/file.txt", path); + SplitArchiveNameAndPath(&path, &namenode, status_); + EXPECT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT) << TF_Message(status_); +} + +TEST_F(HadoopFileSystemTest, HarRootPath) { + const std::string har_path = "har://hdfs-root/user/j.doe/my_archive.har"; + std::string scheme, namenode, path; + ParseHadoopPath(har_path, &scheme, &namenode, &path); + EXPECT_EQ("har", scheme); + EXPECT_EQ("hdfs-root", namenode); + EXPECT_EQ("/user/j.doe/my_archive.har", path); + SplitArchiveNameAndPath(&path, &namenode, status_); + EXPECT_TF_OK(status_); + EXPECT_EQ("har://hdfs-root/user/j.doe/my_archive.har", namenode); + EXPECT_EQ("/", path); +} + +TEST_F(HadoopFileSystemTest, WriteLargeFile) { + if (std::getenv("HADOOP_TEST_LARGE_FILE") != "1") GTEST_SKIP(); + const std::string path = TmpDir("WriteLargeFile"); + const size_t file_size = + static_cast(std::numeric_limits::max()) + 1024; + // Fake a test string. + std::string source(file_size, {}); + for (size_t i = 0; i < file_size; ++i) source[i] = (i % 128); + WriteString(path, source); + ASSERT_TF_OK(status_); + auto result = ReadAll(path); + EXPECT_TF_OK(status_); + EXPECT_EQ(source, result); +} +// NewAppendableFile() is not testable. Local filesystem maps to +// ChecksumFileSystem in Hadoop, where appending is an unsupported operation. + +} // namespace +} // namespace tensorflow + +GTEST_API_ int main(int argc, char** argv) { + tensorflow::testing::InstallStacktraceHandler(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/BUILD b/tensorflow/c/experimental/filesystem/plugins/s3/BUILD index 56bd3b4a75c..a2108d06cbb 100644 --- a/tensorflow/c/experimental/filesystem/plugins/s3/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/s3/BUILD @@ -26,6 +26,8 @@ cc_library( }), deps = [ ":aws_crypto", + ":aws_logging", + "//tensorflow/c:logging", "//tensorflow/c:tf_status", "//tensorflow/c/experimental/filesystem:filesystem_interface", "@aws", @@ -45,6 +47,18 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "aws_logging", + srcs = ["aws_logging.cc"], + hdrs = ["aws_logging.h"], + deps = [ + "//tensorflow/c:logging", + "@aws", + "@com_google_absl//absl/synchronization", + ], + alwayslink = 1, +) + tf_cc_test( name = "s3_filesystem_test", srcs = [ diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.cc b/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.cc new file mode 100644 index 00000000000..353b733fd25 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.cc @@ -0,0 +1,159 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h" + +#include +#include +#include + +#include +#include +#include + +#include "absl/synchronization/mutex.h" +#include "tensorflow/c/logging.h" + +static constexpr char kAWSLoggingTag[] = "AWSLogging"; + +static const std::map + log_levels_string_to_aws = { + {"off", Aws::Utils::Logging::LogLevel::Off}, + {"fatal", Aws::Utils::Logging::LogLevel::Fatal}, + {"error", Aws::Utils::Logging::LogLevel::Error}, + {"warn", Aws::Utils::Logging::LogLevel::Warn}, + {"info", Aws::Utils::Logging::LogLevel::Info}, + {"debug", Aws::Utils::Logging::LogLevel::Debug}, + {"trace", Aws::Utils::Logging::LogLevel::Trace}}; + +static const std::map + log_levels_tf_to_aws = {{0, Aws::Utils::Logging::LogLevel::Info}, + {1, Aws::Utils::Logging::LogLevel::Warn}, + {2, Aws::Utils::Logging::LogLevel::Error}, + {3, Aws::Utils::Logging::LogLevel::Fatal}}; + +namespace tf_s3_filesystem { + +AWSLogSystem::AWSLogSystem(Aws::Utils::Logging::LogLevel log_level) + : log_level_(log_level) {} + +void AWSLogSystem::LogMessage(Aws::Utils::Logging::LogLevel log_level, + const std::string& message) { + if (message == "Initializing Curl library") return; + switch (log_level) { + case Aws::Utils::Logging::LogLevel::Info: + TF_Log(TF_INFO, message.c_str()); + break; + case Aws::Utils::Logging::LogLevel::Warn: + TF_Log(TF_WARNING, message.c_str()); + break; + case Aws::Utils::Logging::LogLevel::Error: + TF_Log(TF_ERROR, message.c_str()); + break; + case Aws::Utils::Logging::LogLevel::Fatal: + TF_Log(TF_FATAL, message.c_str()); + break; + default: + // this will match for DEBUG, TRACE + TF_Log(TF_INFO, message.c_str()); + break; + } +} + +void AWSLogSystem::Log(Aws::Utils::Logging::LogLevel log_level, const char* tag, + const char* format, ...) { + char buffer[256]; + va_list args; + va_start(args, format); + vsnprintf(buffer, 256, format, args); + va_end(args); + LogMessage(log_level, buffer); +} + +void AWSLogSystem::LogStream(Aws::Utils::Logging::LogLevel log_level, + const char* tag, + const Aws::OStringStream& message_stream) { + LogMessage(log_level, message_stream.rdbuf()->str().c_str()); +} + +void AWSLogSystem::Flush() { return; } + +static Aws::Utils::Logging::LogLevel TfLogLevelToAwsLogLevel(int level) { + // Converts TF Log Levels INFO, WARNING, ERROR and FATAL to the AWS enum + // values for the levels + if (log_levels_tf_to_aws.find(level) != log_levels_tf_to_aws.end()) { + return log_levels_tf_to_aws.at(level); + } else { + // default to fatal + return Aws::Utils::Logging::LogLevel::Fatal; + } +} + +static Aws::Utils::Logging::LogLevel ParseAwsLogLevelFromEnv() { + // defaults to FATAL log level for the AWS SDK + // this is because many normal tensorflow operations are logged as errors in + // the AWS SDK such as checking if a file exists can log an error in AWS SDK + // if the file does not actually exist. Another such case is when reading a + // file till the end, TensorFlow expects to see an InvalidRange exception at + // the end, but this would be an error in the AWS SDK. This confuses users, + // hence the default setting. 
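  // For example, with the tables above: AWS_LOG_LEVEL=warn selects Warn,
  // AWS_LOG_LEVEL=2 is read as a TensorFlow log level and maps to Error, and
  // an unset or unrecognized value keeps the Fatal default.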
+ Aws::Utils::Logging::LogLevel log_level = + Aws::Utils::Logging::LogLevel::Fatal; + + const char* aws_env_var_val = getenv("AWS_LOG_LEVEL"); + if (aws_env_var_val != nullptr) { + std::string maybe_integer_str(aws_env_var_val, strlen(aws_env_var_val)); + std::istringstream ss(maybe_integer_str); + int level; + ss >> level; + if (ss.fail()) { + // wasn't a number + // expecting a string + std::string level_str = maybe_integer_str; + if (log_levels_string_to_aws.find(level_str) != + log_levels_string_to_aws.end()) { + log_level = log_levels_string_to_aws.at(level_str); + } + } else { + // backwards compatibility + // valid number, but this number follows the standard TensorFlow log + // levels need to convert this to AWS SDK logging level number + log_level = TfLogLevelToAwsLogLevel(level); + } + } + return log_level; +} + +static bool initialized = false; +ABSL_CONST_INIT static absl::Mutex s3_logging_mutex(absl::kConstInit); +void AWSLogSystem::InitializeAWSLogging() { + absl::MutexLock l(&s3_logging_mutex); + if (!initialized) { + Aws::Utils::Logging::InitializeAWSLogging(Aws::MakeShared( + kAWSLoggingTag, ParseAwsLogLevelFromEnv())); + initialized = true; + return; + } +} + +void AWSLogSystem::ShutdownAWSLogging() { + absl::MutexLock l(&s3_logging_mutex); + if (initialized) { + Aws::Utils::Logging::ShutdownAWSLogging(); + initialized = false; + return; + } +} + +} // namespace tf_s3_filesystem diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h b/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h new file mode 100644 index 00000000000..afecd7e5e62 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_ + +#include +#include + +#include +#include + +namespace tf_s3_filesystem { + +class AWSLogSystem : public Aws::Utils::Logging::LogSystemInterface { + public: + static void InitializeAWSLogging(); + static void ShutdownAWSLogging(); + + explicit AWSLogSystem(Aws::Utils::Logging::LogLevel log_level); + virtual ~AWSLogSystem() = default; + + // Gets the currently configured log level. + Aws::Utils::Logging::LogLevel GetLogLevel(void) const override { + return log_level_; + } + + // Set a new log level. This has the immediate effect of changing the log. + void SetLogLevel(Aws::Utils::Logging::LogLevel log_level) { + log_level_.store(log_level); + } + + // Does a printf style output to ProcessFormattedStatement. Don't use this, + // it's unsafe. See LogStream. + void Log(Aws::Utils::Logging::LogLevel log_level, const char* tag, + const char* format, ...) override; + + // Writes the stream to ProcessFormattedStatement. 
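+  // Like Log(), this funnels into LogMessage(), which maps AWS SDK levels to
+  // TF_Log severities (Debug and Trace are emitted as INFO).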
+ void LogStream(Aws::Utils::Logging::LogLevel log_level, const char* tag, + const Aws::OStringStream& messageStream) override; + + // Flushes the buffered messages if the logger supports buffering + void Flush() override; + + private: + void LogMessage(Aws::Utils::Logging::LogLevel log_level, + const std::string& message); + std::atomic log_level_; +}; + +} // namespace tf_s3_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc index 7e1b36f2dcc..9ff07633f2a 100644 --- a/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc @@ -38,6 +38,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h" +#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h" +#include "tensorflow/c/logging.h" #include "tensorflow/c/tf_status.h" // Implementation of a filesystem for S3 environments. @@ -186,6 +188,8 @@ static void GetS3Client(tf_s3_filesystem::S3File* s3_file) { absl::MutexLock l(&s3_file->initialization_lock); if (s3_file->s3_client.get() == nullptr) { + tf_s3_filesystem::AWSLogSystem::InitializeAWSLogging(); + Aws::SDKOptions options; options.cryptoOptions.sha256Factory_create_fn = []() { return Aws::MakeShared( @@ -250,6 +254,7 @@ static void ShutdownClient(Aws::S3::S3Client* s3_client) { delete s3_client; Aws::SDKOptions options; Aws::ShutdownAPI(options); + tf_s3_filesystem::AWSLogSystem::ShutdownAWSLogging(); } } @@ -281,6 +286,7 @@ void Cleanup(TF_RandomAccessFile* file) { static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n, char* buffer, TF_Status* status) { + TF_VLog(3, "ReadFile using S3Client\n"); Aws::S3::Model::GetObjectRequest get_object_request; get_object_request.WithBucket(s3_file->bucket).WithKey(s3_file->object); Aws::String bytes = @@ -306,12 +312,14 @@ static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n, static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n, char* buffer, TF_Status* status) { + TF_VLog(3, "Using TransferManager\n"); auto create_download_stream = [&]() { return Aws::New( "S3ReadStream", Aws::New( "S3ReadStream", reinterpret_cast(buffer), n)); }; + TF_VLog(3, "Created stream to read with transferManager\n"); auto handle = s3_file->transfer_manager->DownloadFile( s3_file->bucket, s3_file->object, offset, n, create_download_stream); handle->WaitUntilFinished(); @@ -322,6 +330,10 @@ static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n, Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE && retries++ < kDownloadRetries) { // Only failed parts will be downloaded again. + TF_VLog( + 1, + "Retrying read of s3://%s/%s after failure. 
Current retry count: %u\n", + s3_file->bucket.c_str(), s3_file->object.c_str(), retries); s3_file->transfer_manager->RetryDownload(handle); handle->WaitUntilFinished(); } @@ -341,6 +353,8 @@ static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n, int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, char* buffer, TF_Status* status) { auto s3_file = static_cast(file->plugin_file); + TF_VLog(1, "ReadFilefromS3 s3://%s/%s from %u for n: %u\n", + s3_file->bucket.c_str(), s3_file->object.c_str(), offset, n); if (s3_file->use_multi_part_download) return ReadS3TransferManager(s3_file, offset, n, buffer, status); else @@ -416,6 +430,8 @@ void Sync(const TF_WritableFile* file, TF_Status* status) { TF_SetStatus(status, TF_OK, ""); return; } + TF_VLog(1, "WriteFileToS3: s3://%s/%s\n", s3_file->bucket.c_str(), + s3_file->object.c_str()); auto position = static_cast(s3_file->outfile->tellp()); auto handle = s3_file->transfer_manager->UploadFile( s3_file->outfile, s3_file->bucket, s3_file->object, @@ -426,6 +442,10 @@ void Sync(const TF_WritableFile* file, TF_Status* status) { while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED && retries++ < kUploadRetries) { // if multipart upload was used, only the failed parts will be re-sent + TF_VLog(1, + "Retrying upload of s3://%s/%s after failure. Current retry count: " + "%u\n", + s3_file->bucket.c_str(), s3_file->object.c_str(), retries); s3_file->transfer_manager->RetryUpload(s3_file->outfile, handle); handle->WaitUntilFinished(); } @@ -613,6 +633,7 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, void Stat(const TF_Filesystem* filesystem, const char* path, TF_FileStatistics* stats, TF_Status* status) { + TF_VLog(1, "Stat on path: %s\n", path); Aws::String bucket, object; ParseS3Path(path, true, &bucket, &object, status); if (TF_GetCode(status) != TF_OK) return; @@ -737,6 +758,8 @@ static void SimpleCopyFile(const Aws::String& source, const Aws::String& bucket_dst, const Aws::String& object_dst, S3File* s3_file, TF_Status* status) { + TF_VLog(1, "SimpleCopyFile from %s to %s/%s\n", bucket_dst.c_str(), + object_dst.c_str()); Aws::S3::Model::CopyObjectRequest copy_object_request; copy_object_request.WithCopySource(source) .WithBucket(bucket_dst) @@ -801,6 +824,8 @@ static void MultiPartCopy(const Aws::String& source, const Aws::String& object_dst, const size_t num_parts, const uint64_t file_size, S3File* s3_file, TF_Status* status) { + TF_VLog(1, "MultiPartCopy from %s to %s/%s\n", bucket_dst.c_str(), + object_dst.c_str()); Aws::S3::Model::CreateMultipartUploadRequest create_multipart_upload_request; create_multipart_upload_request.WithBucket(bucket_dst).WithKey(object_dst); @@ -827,6 +852,8 @@ static void MultiPartCopy(const Aws::String& source, auto chunk_size = s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD]; + TF_VLog(1, "Copying from %s in %u parts of size %u each\n", source.c_str(), + num_parts, chunk_size); size_t retries = 0; while (retries++ < 3) { // Queue up parts. @@ -891,6 +918,9 @@ static void MultiPartCopy(const Aws::String& source, status); } else { // Retry. 
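+        // The part is counted as unfinished again so the enclosing retry
+        // loop (at most 3 passes) re-attempts it.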
+ TF_Log(TF_ERROR, + "Retrying failed copy of part %u due to an error with S3\n", + part_number); num_finished_parts--; } } @@ -967,6 +997,7 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst, void DeleteFile(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { + TF_VLog(1, "DeleteFile: %s\n", path); Aws::String bucket, object; ParseS3Path(path, false, &bucket, &object, status); if (TF_GetCode(status) != TF_OK) return; @@ -985,6 +1016,7 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path, void CreateDir(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { + TF_VLog(1, "CreateDir: %s\n", path); Aws::String bucket, object; ParseS3Path(path, true, &bucket, &object, status); if (TF_GetCode(status) != TF_OK) return; @@ -1026,6 +1058,7 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path, void DeleteDir(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { + TF_VLog(1, "DeleteDir: %s\n", path); Aws::String bucket, object; ParseS3Path(path, false, &bucket, &object, status); if (TF_GetCode(status) != TF_OK) return; @@ -1060,6 +1093,7 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path, void RenameFile(const TF_Filesystem* filesystem, const char* src, const char* dst, TF_Status* status) { + TF_VLog(1, "RenameFile from: %s to %s\n", src, dst); Aws::String bucket_src, object_src; ParseS3Path(src, false, &bucket_src, &object_src, status); if (TF_GetCode(status) != TF_OK) return; @@ -1120,6 +1154,7 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src, int GetChildren(const TF_Filesystem* filesystem, const char* path, char*** entries, TF_Status* status) { + TF_VLog(1, "GetChildren for path: %s\n", path); Aws::String bucket, prefix; ParseS3Path(path, true, &bucket, &prefix, status); if (TF_GetCode(status) != TF_OK) return -1; diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index e3acdf7e2c3..5386c0cf3f7 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -3,6 +3,24 @@ package( licenses = ["notice"], # Apache 2.0 ) +cc_library( + name = "array_grad", + srcs = ["array_grad.cc"], + hdrs = [ + "array_grad.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/c/eager:gradients_internal", + "//tensorflow/core/lib/llvm_rtti", + ], +) + cc_library( name = "math_grad", srcs = ["math_grad.cc"], @@ -13,11 +31,63 @@ cc_library( "//tensorflow:internal", ], deps = [ - "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", - "//tensorflow/c/eager:c_api_unified_internal", - "//tensorflow/c/eager:gradients", + "//tensorflow/c/eager:gradients_internal", "//tensorflow/c/experimental/ops:array_ops", - "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", + ], +) + +cc_library( + name = "nn_grad", + srcs = ["nn_grad.cc"], + hdrs = [ + "nn_grad.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:gradients_internal", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/ops:array_ops", + 
"//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:errors", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "gradients", + hdrs = [ + "array_grad.h", + "math_grad.h", + "nn_grad.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":array_grad", + ":math_grad", + ":nn_grad", + ], +) + +filegroup( + name = "pywrap_required_hdrs", + srcs = [ + "array_grad.h", + "math_grad.h", + "nn_grad.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", ], ) diff --git a/tensorflow/c/experimental/gradients/array_grad.cc b/tensorflow/c/experimental/gradients/array_grad.cc new file mode 100644 index 00000000000..069209a4b6b --- /dev/null +++ b/tensorflow/c/experimental/gradients/array_grad.cc @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/gradients/array_grad.h" + +namespace tensorflow { +namespace gradients { +namespace { +using std::vector; +class IdentityNGradientFunction : public GradientFunction { + public: + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + grad_outputs->resize(grad_inputs.size(), nullptr); + for (int i = 0; i < grad_inputs.size(); i++) { + auto grad_input = grad_inputs[i]; + // TODO(srbs): Should we add a copy contructor to AbstractTensorHandle + // that takes care of this similar to `Tensor`? + if (grad_input) { + grad_input->Ref(); + } + (*grad_outputs)[i] = grad_input; + } + return Status::OK(); + } + ~IdentityNGradientFunction() override {} +}; +} // namespace + +BackwardFunction* IdentityNRegisterer(const ForwardOperation& op) { + auto gradient_function = new IdentityNGradientFunction; + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/array_grad.h b/tensorflow/c/experimental/gradients/array_grad.h new file mode 100644 index 00000000000..edeeb5fcb4a --- /dev/null +++ b/tensorflow/c/experimental/gradients/array_grad.h @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_ + +#include "tensorflow/c/eager/gradients.h" + +namespace tensorflow { +namespace gradients { +BackwardFunction* IdentityNRegisterer(const ForwardOperation& op); +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_ diff --git a/tensorflow/c/experimental/gradients/math_grad.cc b/tensorflow/c/experimental/gradients/math_grad.cc index 47bd8cce23d..c2aa9caf814 100644 --- a/tensorflow/c/experimental/gradients/math_grad.cc +++ b/tensorflow/c/experimental/gradients/math_grad.cc @@ -14,9 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/gradients/math_grad.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" -using tensorflow::ops::Identity; +using std::vector; +using tensorflow::ops::Conj; +using tensorflow::ops::MatMul; +using tensorflow::ops::Mul; namespace tensorflow { namespace gradients { @@ -24,30 +31,184 @@ namespace { class AddGradientFunction : public GradientFunction { public: - Status Compute(Context* ctx, - absl::Span grad_inputs, - std::vector* grad_outputs) override { + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { grad_outputs->resize(2); - std::vector identity_outputs(1); - // TODO(b/145674566): Handle name unification in tracing code. // TODO(b/161805092): Support broadcasting. 
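+    // For z = x + y the local gradient w.r.t. each input is 1, so the
+    // incoming gradient is forwarded to both outputs. Each handle is Ref()'d
+    // because the caller takes ownership of the returned gradients.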
- TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, - absl::MakeSpan(identity_outputs), - "Identity0")); - (*grad_outputs)[0] = identity_outputs[0]; - TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, - absl::MakeSpan(identity_outputs), - "Identity1")); - (*grad_outputs)[1] = identity_outputs[0]; + + DCHECK(grad_inputs[0]); + (*grad_outputs)[0] = grad_inputs[0]; + (*grad_outputs)[1] = grad_inputs[0]; + + (*grad_outputs)[0]->Ref(); + (*grad_outputs)[1]->Ref(); return Status::OK(); } ~AddGradientFunction() override {} }; +class ExpGradientFunction : public GradientFunction { + public: + explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) { + exp->Ref(); + } + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + vector conj_outputs(1); + std::string name = "Conj_Exp_Grad"; + TF_RETURN_IF_ERROR(Conj(ctx->ctx, {exp_.get()}, + absl::MakeSpan(conj_outputs), name.c_str())); + AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]); + grad_outputs->resize(1); + + name = "Mul_Exp_Grad"; + TF_RETURN_IF_ERROR(Mul(ctx->ctx, {conj_outputs[0], grad_inputs[0]}, + absl::MakeSpan(*grad_outputs), name.c_str())); + return Status::OK(); + } + ~ExpGradientFunction() override {} + + private: + AbstractTensorHandlePtr exp_; +}; + +class MatMulGradientFunction : public GradientFunction { + public: + explicit MatMulGradientFunction(vector f_inputs, + AttrBuilder f_attrs) + : forward_inputs(f_inputs), forward_attrs(f_attrs) {} + + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + /* Given upstream grad U and a matmul op A*B, the gradients are: + * + * dA = U * B.T + * dB = A.T * U + * + * where A.T means `transpose(A)` + */ + AbstractTensorHandle* upstream_grad = grad_inputs[0]; + grad_outputs->resize(2); + + // Get transpose attrs + bool t_a; + TF_RETURN_IF_ERROR(forward_attrs.Get("transpose_a", &t_a)); + + bool t_b; + TF_RETURN_IF_ERROR(forward_attrs.Get("transpose_b", &t_b)); + + // Conj each input + vector conj_outputs(1); + std::string name = "Conj_A_MatMul_Grad"; + TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[0]}, + absl::MakeSpan(conj_outputs), name.c_str())); + + AbstractTensorHandle* A = conj_outputs[0]; + + name = "Conj_B_MatMul_Grad"; + TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[1]}, + absl::MakeSpan(conj_outputs), name.c_str())); + + AbstractTensorHandle* B = conj_outputs[0]; + + // Calc Grad + vector matmul_A_outputs(1); + vector matmul_B_outputs(1); + std::string name_grad_A = "MatMul_Grad_A"; + std::string name_grad_B = "MatMul_Grad_B"; + if (!t_a && !t_b) { + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B}, + absl::MakeSpan(matmul_A_outputs), + name_grad_A.c_str(), + /*transpose_a = */ false, + /*transpose_b = */ true)); + + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad}, + absl::MakeSpan(matmul_B_outputs), + name_grad_B.c_str(), + /*transpose_a = */ true, + /*transpose_b = */ false)); + } else if (!t_a && t_b) { + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B}, + absl::MakeSpan(matmul_A_outputs), + name_grad_A.c_str(), + /*transpose_a = */ false, + /*transpose_b = */ false)); + + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A}, + absl::MakeSpan(matmul_B_outputs), + name_grad_B.c_str(), + /*transpose_a = */ true, + /*transpose_b = */ false)); + + } else if (t_a && !t_b) { + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad}, + absl::MakeSpan(matmul_A_outputs), + name_grad_A.c_str(), + /*transpose_a = */ 
false, + /*transpose_b = */ true)); + + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad}, + absl::MakeSpan(matmul_B_outputs), + name_grad_B.c_str(), + /*transpose_a = */ false, + /*transpose_b = */ false)); + } else { // t_a && t_b + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad}, + absl::MakeSpan(matmul_A_outputs), + name_grad_A.c_str(), + /*transpose_a = */ true, + /*transpose_b = */ true)); + + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A}, + absl::MakeSpan(matmul_B_outputs), + name_grad_B.c_str(), + /*transpose_a = */ true, + /*transpose_b = */ true)); + } + + // Gradient for A + (*grad_outputs)[0] = matmul_A_outputs[0]; + + // Gradient for B + (*grad_outputs)[1] = matmul_B_outputs[0]; + return Status::OK(); + } + ~MatMulGradientFunction() override {} + + private: + vector forward_inputs; + AttrBuilder forward_attrs; +}; + } // namespace -GradientFunction* AddRegisterer(const ForwardOperation& op) { - return new AddGradientFunction; +BackwardFunction* AddRegisterer(const ForwardOperation& op) { + auto gradient_function = new AddGradientFunction; + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); } + +BackwardFunction* ExpRegisterer(const ForwardOperation& op) { + auto gradient_function = new ExpGradientFunction(op.outputs[0]); + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + +BackwardFunction* MatMulRegisterer(const ForwardOperation& op) { + auto gradient_function = new MatMulGradientFunction(op.inputs, op.attrs); + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/math_grad.h b/tensorflow/c/experimental/gradients/math_grad.h index 473253f9b27..205419e1201 100644 --- a/tensorflow/c/experimental/gradients/math_grad.h +++ b/tensorflow/c/experimental/gradients/math_grad.h @@ -19,8 +19,10 @@ limitations under the License. namespace tensorflow { namespace gradients { -GradientFunction* AddRegisterer(const ForwardOperation& op); +BackwardFunction* AddRegisterer(const ForwardOperation& op); +BackwardFunction* ExpRegisterer(const ForwardOperation& op); +BackwardFunction* MatMulRegisterer(const ForwardOperation& op); } // namespace gradients } // namespace tensorflow -#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_ +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_ \ No newline at end of file diff --git a/tensorflow/c/experimental/gradients/nn_grad.cc b/tensorflow/c/experimental/gradients/nn_grad.cc new file mode 100644 index 00000000000..64532c8ffc0 --- /dev/null +++ b/tensorflow/c/experimental/gradients/nn_grad.cc @@ -0,0 +1,133 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/gradients/nn_grad.h" + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" + +using std::vector; +using tensorflow::ops::Mul; +using tensorflow::ops::ReluGrad; + +namespace tensorflow { +namespace gradients { +namespace { + +class ReluGradientFunction : public GradientFunction { + public: + explicit ReluGradientFunction(vector f_outputs) + : forward_outputs(f_outputs) {} + + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + AbstractTensorHandle* upstream_grad = grad_inputs[0]; + AbstractTensorHandle* activations = forward_outputs[0]; + grad_outputs->resize(1); + vector relugrad_outputs(1); + + // Calculate Grad + std::string name = "relu_grad"; + + TF_RETURN_IF_ERROR(ReluGrad(ctx->ctx, {upstream_grad, activations}, + absl::MakeSpan(relugrad_outputs), + name.c_str())); + (*grad_outputs)[0] = relugrad_outputs[0]; + + return Status::OK(); + } + ~ReluGradientFunction() override {} + + private: + vector forward_outputs; +}; + +Status BroadcastMul(AbstractContext* ctx, AbstractTensorHandle* vec, + AbstractTensorHandle* mat, + absl::Span outputs) { + if (!isa(ctx)) { + // TODO(b/168850692): Fix this. 
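+    // Tracing (graph) contexts cannot create the int32 constant (-1) handle
+    // needed for the ExpandDims call below, so only immediate (eager)
+    // execution is supported for now.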
+ return errors::Unimplemented( + "BroadcastMul is not supported in tracing mode yet."); + } + auto imm_ctx = dyn_cast(ctx); + AbstractTensorPtr minus_1(imm_ctx->CreateInt32Scalar(-1)); + ImmediateTensorHandlePtr dim(imm_ctx->CreateLocalHandle(minus_1.get())); + vector expand_dims_outputs(1); + TF_RETURN_IF_ERROR(ops::ExpandDims(ctx, {vec, dim.get()}, + absl::MakeSpan(expand_dims_outputs), + "ExpandDims")); + TF_RETURN_IF_ERROR( + ops::Mul(ctx, {expand_dims_outputs[0], mat}, outputs, "Mul")); + expand_dims_outputs[0]->Unref(); + return Status::OK(); +} + +class SparseSoftmaxCrossEntropyWithLogitsGradientFunction + : public GradientFunction { + public: + explicit SparseSoftmaxCrossEntropyWithLogitsGradientFunction( + vector f_outputs) + : forward_outputs(f_outputs) {} + + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + grad_outputs->resize(2); + + // Grad for Softmax Input + vector mul_outputs(1); + TF_RETURN_IF_ERROR(BroadcastMul( + ctx->ctx, grad_inputs[0], forward_outputs[1], + absl::MakeSpan(mul_outputs))); // upstream_grad * local softmax grad + (*grad_outputs)[0] = mul_outputs[0]; + + // Grad for labels is null + (*grad_outputs)[1] = nullptr; + + return Status::OK(); + } + ~SparseSoftmaxCrossEntropyWithLogitsGradientFunction() override {} + + private: + vector forward_outputs; +}; + +} // namespace + +BackwardFunction* ReluRegisterer(const ForwardOperation& op) { + auto gradient_function = new ReluGradientFunction(op.outputs); + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + +BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer( + const ForwardOperation& op) { + auto gradient_function = + new SparseSoftmaxCrossEntropyWithLogitsGradientFunction(op.outputs); + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/nn_grad.h b/tensorflow/c/experimental/gradients/nn_grad.h new file mode 100644 index 00000000000..034f20d7325 --- /dev/null +++ b/tensorflow/c/experimental/gradients/nn_grad.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_ + +#include "tensorflow/c/eager/gradients.h" + +namespace tensorflow { +namespace gradients { +BackwardFunction* ReluRegisterer(const ForwardOperation& op); +BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer( + const ForwardOperation& op); +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_ diff --git a/tensorflow/c/experimental/ops/BUILD b/tensorflow/c/experimental/ops/BUILD index 312709f4332..d2c22e65f80 100644 --- a/tensorflow/c/experimental/ops/BUILD +++ b/tensorflow/c/experimental/ops/BUILD @@ -15,6 +15,7 @@ cc_library( "//tensorflow:internal", ], deps = [ + "//tensorflow/c/eager:abstract_context", "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:c_api_unified_internal", @@ -22,3 +23,80 @@ cc_library( "//tensorflow/core/platform:errors", ], ) + +cc_library( + name = "math_ops", + srcs = [ + "math_ops.cc", + ], + hdrs = [ + "math_ops.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":array_ops", + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/core:framework", + "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:errors", + ], +) + +cc_library( + name = "nn_ops", + srcs = [ + "nn_ops.cc", + ], + hdrs = [ + "nn_ops.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:errors", + ], +) + +cc_library( + name = "ops", + hdrs = [ + "array_ops.h", + "math_ops.h", + "nn_ops.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":array_ops", + ":math_ops", + ":nn_ops", + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/core/lib/llvm_rtti", + ], +) + +filegroup( + name = "pywrap_required_hdrs", + srcs = [ + "array_ops.h", + "math_ops.h", + "nn_ops.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) diff --git a/tensorflow/c/experimental/ops/array_ops.cc b/tensorflow/c/experimental/ops/array_ops.cc index e38b00088cf..6ea7a0b73f8 100644 --- a/tensorflow/c/experimental/ops/array_ops.cc +++ b/tensorflow/c/experimental/ops/array_ops.cc @@ -14,11 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace ops { -// Creates an Identity op. 
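+// Each wrapper below follows the same pattern: create an AbstractOperation,
+// Reset() it to the target op name, set a debug name when running under a
+// tracing context, add the inputs, and Execute() with the expected number of
+// return values. A minimal usage sketch (assuming an eager AbstractContext*
+// `ctx` and an input handle `x`):
+//
+//   std::vector<AbstractTensorHandle*> outputs(1);
+//   TF_RETURN_IF_ERROR(
+//       ops::Identity(ctx, {x}, absl::MakeSpan(outputs), "Identity"));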
+ Status Identity(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name) { @@ -34,5 +35,51 @@ Status Identity(AbstractContext* ctx, return identity_op->Execute(outputs, &num_retvals); } +Status ZerosLike(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr z_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(z_op->Reset("ZerosLike", /*raw_device_name=*/nullptr)); + if (isa(z_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(z_op.get())->SetOpName(name)); + } + TF_RETURN_IF_ERROR(z_op->AddInput(inputs[0])); + int num_retvals = 1; + return z_op->Execute(outputs, &num_retvals); +} + +Status Shape(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr shape_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(shape_op->Reset("Shape", /*raw_device_name=*/nullptr)); + + if (isa(shape_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(shape_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(shape_op->AddInput(inputs[0])); // input + int num_retvals = 1; + TF_RETURN_IF_ERROR(shape_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +Status ExpandDims(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(op->Reset("ExpandDims", /*raw_device_name=*/nullptr)); + if (isa(op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(op.get())->SetOpName(name)); + } + TF_RETURN_IF_ERROR(op->AddInput(inputs[0])); + TF_RETURN_IF_ERROR(op->AddInput(inputs[1])); + int num_retvals = 1; + return op->Execute(outputs, &num_retvals); +} + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/array_ops.h b/tensorflow/c/experimental/ops/array_ops.h index 8a9db484c2e..a2179d3f137 100644 --- a/tensorflow/c/experimental/ops/array_ops.h +++ b/tensorflow/c/experimental/ops/array_ops.h @@ -15,16 +15,30 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_ #define TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_ +#include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_operation.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" namespace tensorflow { namespace ops { + Status Identity(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); + +Status ZerosLike(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + +Status Shape(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + +Status ExpandDims(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/math_ops.cc b/tensorflow/c/experimental/ops/math_ops.cc new file mode 100644 index 00000000000..2c6d01b5e21 --- /dev/null +++ b/tensorflow/c/experimental/ops/math_ops.cc @@ -0,0 +1,166 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/ops/math_ops.h" + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/errors.h" +namespace tensorflow { +namespace ops { +using tensorflow::tracing::TracingOperation; + +Status Mul(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr mul_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(mul_op->Reset("Mul", /*raw_device_name=*/nullptr)); + if (isa(mul_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(mul_op.get())->SetOpName(name)); + } + TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[0])); + TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[1])); + int num_retvals = 1; + return mul_op->Execute(outputs, &num_retvals); +} + +Status Conj(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + auto dtype = inputs[0]->DataType(); + if (DataTypeIsFloating(BaseType(dtype)) || + DataTypeIsInteger(BaseType(dtype))) { + TF_RETURN_IF_ERROR(Identity(ctx, inputs, outputs, name)); + } else { + return errors::Unimplemented("Conj does not support complex types yet."); + } + return Status::OK(); +} + +Status Add(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr add_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(add_op->Reset("AddV2", /*raw_device_name=*/nullptr)); + + if (isa(add_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(add_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(add_op->AddInput(inputs[0])); + TF_RETURN_IF_ERROR(add_op->AddInput(inputs[1])); + + int num_retvals = 1; + TF_RETURN_IF_ERROR(add_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +Status Sub(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr sub_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(sub_op->Reset("Sub", /*raw_device_name=*/nullptr)); + + if (isa(sub_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(sub_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[0])); + TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[1])); + + int num_retvals = 1; + TF_RETURN_IF_ERROR(sub_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +Status MatMul(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name, + bool transpose_a = false, bool transpose_b = false) { + AbstractOperationPtr matmul_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(matmul_op->Reset("MatMul", /*raw_device_name=*/nullptr)); + + if (isa(matmul_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(matmul_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[0])); + TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[1])); + + TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_a", transpose_a)); + 
TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_b", transpose_b)); + + int num_retvals = 1; + TF_RETURN_IF_ERROR(matmul_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +Status Neg(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr neg_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(neg_op->Reset("Neg", /*raw_device_name=*/nullptr)); + if (isa(neg_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(neg_op.get())->SetOpName(name)); + } + TF_RETURN_IF_ERROR(neg_op->AddInput(inputs[0])); + + int num_retvals = 1; + return neg_op->Execute(outputs, &num_retvals); +} + +Status Sum(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr sum_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(sum_op->Reset("Sum", /*raw_device_name=*/nullptr)); + + if (isa(sum_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(sum_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[0])); // input_vals + TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[1])); // reduction_indices + + int num_retvals = 1; + TF_RETURN_IF_ERROR(sum_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +Status DivNoNan(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr div_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(div_op->Reset("DivNoNan", /*raw_device_name=*/nullptr)); + + if (isa(div_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(div_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0])); // x + TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1])); // y + + int num_retvals = 1; + TF_RETURN_IF_ERROR(div_op->Execute( + outputs, &num_retvals)); // z = x / y, (z_i = 0 if y_i = 0) + return Status::OK(); +} + +} // namespace ops +} // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/math_ops.h b/tensorflow/c/experimental/ops/math_ops.h new file mode 100644 index 00000000000..004b8f2bb4d --- /dev/null +++ b/tensorflow/c/experimental/ops/math_ops.h @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_ + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" + +namespace tensorflow { +namespace ops { +Status Mul(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name); + +Status Conj(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + +Status Add(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name); + +Status MatMul(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name, + bool transpose_a, bool transpose_b); + +Status Neg(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name); + +Status Sum(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name); + +Status Sub(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name); + +Status DivNoNan(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); +} // namespace ops +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_ diff --git a/tensorflow/c/experimental/ops/nn_ops.cc b/tensorflow/c/experimental/ops/nn_ops.cc new file mode 100644 index 00000000000..bcc5586c0ef --- /dev/null +++ b/tensorflow/c/experimental/ops/nn_ops.cc @@ -0,0 +1,85 @@ + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/ops/nn_ops.h" + +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace ops { + +// Softmax Loss given scores and labels, used by the SoftMaxLossGradient +Status SparseSoftmaxCrossEntropyWithLogits( + AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr sm_loss_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(sm_loss_op->Reset("SparseSoftmaxCrossEntropyWithLogits", + /*raw_device_name=*/nullptr)); + + if (isa(sm_loss_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(sm_loss_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[0])); // input scores + TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[1])); // labels + + // Outputs will contain: [loss_vals, gradients]. 
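+  // num_retvals is 2 because the op returns both the per-example loss and
+  // the backpropagated gradient with respect to the logits.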
+ int num_retvals = 2; + TF_RETURN_IF_ERROR(sm_loss_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +// Computes Relu gradient given input features +Status ReluGrad(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr relugrad_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR( + relugrad_op->Reset("ReluGrad", /*raw_device_name=*/nullptr)); + + if (isa(relugrad_op.get())) { + TF_RETURN_IF_ERROR(dyn_cast(relugrad_op.get()) + ->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[0])); // upstream grads + TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[1])); // relu inputs + + int num_retvals = 1; + TF_RETURN_IF_ERROR(relugrad_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +Status Relu(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr relu_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(relu_op->Reset("Relu", /*raw_device_name=*/nullptr)); + + if (isa(relu_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(relu_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(relu_op->AddInput(inputs[0])); + + int num_retvals = 1; + TF_RETURN_IF_ERROR(relu_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +} // namespace ops +} // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/nn_ops.h b/tensorflow/c/experimental/ops/nn_ops.h new file mode 100644 index 00000000000..142b74aff0e --- /dev/null +++ b/tensorflow/c/experimental/ops/nn_ops.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_ + +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" + +namespace tensorflow { +namespace ops { + +Status SparseSoftmaxCrossEntropyWithLogits( + AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name); + +Status ReluGrad(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + +Status Relu(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + +} // namespace ops +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_ diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index 8078758328c..2feb7c1b33e 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -44,7 +44,9 @@ cc_library( ], deps = [ ":concrete_function", + ":signature_def_function", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -70,6 +72,26 @@ cc_library( ], ) +cc_library( + name = "signature_def_function", + hdrs = [ + "signature_def_function.h", + ], + deps = [ + ":signature_def_function_metadata", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "signature_def_function_metadata", + hdrs = [ + "signature_def_function_metadata.h", + ], +) + cc_library( name = "test_utils", testonly = True, @@ -115,6 +137,7 @@ cc_library( ":concrete_function", ":saved_model_api", ":saved_model_utils", + ":signature_def_function", "//tensorflow/c:tensor_interface", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_tensor_handle", @@ -206,13 +229,30 @@ tf_cc_test( "//tensorflow/c/experimental/saved_model/core/revived_types:constant", "//tensorflow/core:all_kernels", "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/common_runtime:core_cpu_lib", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:core", + "//tensorflow/core/common_runtime/eager:tensor_handle", + ], +) + +tf_cc_test( + name = "signature_flattening_test", + srcs = [ + "signature_flattening_test.cc", + ], + deps = [ + ":saved_model_utils", + "//tensorflow/c/experimental/saved_model/core:tf_concrete_function_test_protos", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime/eager:core", ], ) diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.h b/tensorflow/c/experimental/saved_model/core/concrete_function.h index da3a64b91a3..48a20ef7768 100644 --- a/tensorflow/c/experimental/saved_model/core/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.h @@ -26,10 +26,14 @@ limitations under the License. 
namespace tensorflow { -// Note that ConcreteFunctions's lifetimes are effectively bound -// to the SavedModel they are loaded from, since they retain pointers -// to the TensorHandles owned by the SavedModel, and the FunctionDef -// of the SavedModel. +// ConcreteFunctions correspond to an instance of a tf.function with a known set +// of inputs (either through get_concrete_function) or an input_signature. +// ConcreteFunction attempts to preserve the user-facing semantics of the +// tf.function python API and can take a limited set of types as arguments +// (to be modeled in tensorflow::Value), not just Tensors. +// SavedModelAPI's ConcreteFunctions' lifetimes are bound to the SavedModel they +// are loaded from, since they retain pointers to the TensorHandles owned by the +// SavedModel, and the FunctionDef of the SavedModel. // Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock // TFRT integration with TF Serving. Do not add more virtual implementations of // this class. Eventually we want to remove this virtual base class indirection @@ -39,8 +43,8 @@ class ConcreteFunction { virtual ~ConcreteFunction() = default; // This method returns the "Call" Op used to execute the function. - virtual Status GetCallOp(absl::Span inputs, - ImmediateOpPtr* out) = 0; + virtual Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const = 0; virtual const FunctionMetadata& GetFunctionMetadata() const = 0; }; diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc index 492a58f816d..be9ffff99ff 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc @@ -37,10 +37,11 @@ static const char kNoSharingResourceID[] = Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, + const char* raw_device_name, ImmediateTensorHandlePtr* handle) { ImmediateOpPtr varhandle_op(ctx->CreateOperation()); - TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr)); + TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", raw_device_name)); TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype)); // Note that if shape is unknown rank, shape.dim_sizes() will be empty, and diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h index 13c941a77fe..accad1591da 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h @@ -31,6 +31,7 @@ namespace internal { // https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872 Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, + const char* raw_device_name, ImmediateTensorHandlePtr* handle); // Executes an AssignVariableOp using `ctx`, assigning the variable associated diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc index 55a4a32e983..5ce027fe6d8 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc @@ -55,7 +55,7 @@ TEST_F(VariableOpsTest, CreateVariableSuccessful) { // Create a DT_Resource TensorHandle that points 
to a scalar DT_FLOAT tensor ImmediateTensorHandlePtr handle; TF_EXPECT_OK(internal::CreateUninitializedResourceVariable( - context(), DT_FLOAT, {}, &handle)); + context(), DT_FLOAT, {}, nullptr, &handle)); // The created TensorHandle should be a DT_Resource EXPECT_EQ(handle->DataType(), DT_RESOURCE); } @@ -65,7 +65,7 @@ TEST_F(VariableOpsTest, DestroyVariableSuccessful) { // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor ImmediateTensorHandlePtr handle; TF_EXPECT_OK(internal::CreateUninitializedResourceVariable( - context(), DT_FLOAT, {}, &handle)); + context(), DT_FLOAT, {}, nullptr, &handle)); // Destroy the variable TF_EXPECT_OK(internal::DestroyResource(context(), handle.get())); @@ -76,7 +76,7 @@ TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) { // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor ImmediateTensorHandlePtr variable; TF_EXPECT_OK(internal::CreateUninitializedResourceVariable( - context(), DT_FLOAT, {}, &variable)); + context(), DT_FLOAT, {}, nullptr, &variable)); // Create a Scalar float TensorHandle with value 42, and assign it to // the variable. diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index 2b883618c87..25cac39daa0 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -28,6 +28,26 @@ cc_library( ], ) +cc_library( + name = "flat_tensor_function", + srcs = [ + "flat_tensor_function.cc", + ], + hdrs = [ + "flat_tensor_function.h", + ], + deps = [ + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:context", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "variable", srcs = [ @@ -68,7 +88,7 @@ cc_library( "tf_concrete_function.h", ], deps = [ - ":tensorhandle_convertible", + ":flat_tensor_function", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_operation", @@ -81,3 +101,26 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +cc_library( + name = "tf_signature_def_function", + srcs = [ + "tf_signature_def_function.cc", + ], + hdrs = [ + "tf_signature_def_function.h", + ], + deps = [ + ":flat_tensor_function", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core:signature_def_function", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:context", + "@com_google_absl//absl/types:span", + ], +) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc new file mode 100644 index 00000000000..ad9f896f43d --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +FlatTensorFunction::FlatTensorFunction( + const std::string& name, + std::vector captures, + ImmediateExecutionContext* ctx) + : name_(name), captures_(std::move(captures)), ctx_(ctx) {} + +FlatTensorFunction::~FlatTensorFunction() { + Status status = ctx_->RemoveFunction(name_); + if (!status.ok()) { + LOG(ERROR) << "Failed to remove functiondef " << name_ << ". " + << status.error_message(); + } +} + +Status FlatTensorFunction::Create( + const FunctionDef* function_def, + std::vector captures, + ImmediateExecutionContext* ctx, std::unique_ptr* out) { + TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def)); + out->reset(new FlatTensorFunction(function_def->signature().name(), + std::move(captures), ctx)); + return Status(); +} + +Status FlatTensorFunction::MakeCallOp( + absl::Span inputs, ImmediateOpPtr* out) const { + out->reset(ctx_->CreateOperation()); + // In eager mode, TF2 python executes functions by constructing an op with + // the name of the functiondef: + // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545 + // In graph mode, we create a PartitionedCallOp instead: + // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573 + + // TODO(bmzhao): After discussing with Allen, we should execute this via a + // PartitionedCallOp for compatibility with "tooling that assumes functions in + // graphs are PartitionedCallOps". + TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr)); + + // Adding the user-provided inputs to the function. + TF_RETURN_IF_ERROR((*out)->AddInputList(inputs)); + + absl::Span captures( + reinterpret_cast(captures_.data()), + captures_.size()); + + // Adding the captures of the function. 
+ TF_RETURN_IF_ERROR((*out)->AddInputList(captures)); + return Status(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h new file mode 100644 index 00000000000..e6bcdec7e3a --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h @@ -0,0 +1,84 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_ + +#include +#include +#include +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// FlatTensorFunction models a TF2 eager runtime view of a callable function, +// taking + returning flat lists of tensors, including any captures. +// Effectively, it is a thin wrapper around a FunctionDef owned by the +// EagerContext, and any TensorHandle captures associated with the function. The +// MakeCallOp method handles the logic of marshaling captures after the user +// provided inputs automatically. +// Note(bmzhao): This class is mainly intended to house low-level reusable +// function logic between SignatureDefFunction and ConcreteFunction, which +// present higher level interfaces. This type does *not* hold any "function +// metadata". +class FlatTensorFunction { + public: + // Factory for creating a FlatTensorFunction. + // + // Params: + // function_def - The function_def associated with the created + // FlatTensorFunction. FlatTensorFunction will register this + // function_def with `ctx` on creation, and de-register it on + // destruction. function_def must be non-null, but + // otherwise has no lifetime requirements. + // captures - The captured TensorHandles associated with this + // FlatTensorFunction. + // ctx - A handle to the Tensorflow runtime. This MUST be non-null and + // outlive TFConcreteFunction. + // out - The output FlatTensorFunction. + static Status Create(const FunctionDef* function_def, + std::vector captures, + ImmediateExecutionContext* ctx, + std::unique_ptr* out); + + // This method creates a "Call" Op used to execute the function. 
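For orientation, here is a minimal sketch of how a revived function type is expected to drive FlatTensorFunction end to end. The template arguments are collapsed in the listing above, so the parameter types below (`std::vector<ImmediateTensorHandlePtr>` for captures, `absl::Span<AbstractTensorHandle* const>` for inputs) and the final `Execute` call are stated as assumptions rather than quoted from the patch:

```cpp
Status RunFlatFunction(ImmediateExecutionContext* ctx,
                       const FunctionDef* fdef,
                       std::vector<ImmediateTensorHandlePtr> captures,
                       absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle*> outputs) {
  // Create registers the FunctionDef with `ctx`; the registration is removed
  // again when `func` is destroyed.
  std::unique_ptr<FlatTensorFunction> func;
  TF_RETURN_IF_ERROR(
      FlatTensorFunction::Create(fdef, std::move(captures), ctx, &func));

  // MakeCallOp resets an op named after the FunctionDef and appends the
  // captures after the user-provided `inputs`.
  ImmediateOpPtr call_op;
  TF_RETURN_IF_ERROR(func->MakeCallOp(inputs, &call_op));

  int num_retvals = static_cast<int>(outputs.size());
  return call_op->Execute(outputs, &num_retvals);
}
```

The key property the wrapper guarantees is ordering: callers never interleave arguments themselves, since user inputs always precede the captures appended by MakeCallOp.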
+ Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const; + + ~FlatTensorFunction(); + + private: + FlatTensorFunction(const std::string& name, + std::vector captures, + ImmediateExecutionContext* ctx); + + FlatTensorFunction(const FlatTensorFunction&) = delete; + FlatTensorFunction& operator=(const FlatTensorFunction&) = delete; + + // Name of the FunctionDef corresponding to this TFConcreteFunction + std::string name_; + std::vector captures_; + ImmediateExecutionContext* ctx_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc index f734f9eca66..d9773a4520f 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/platform/errors.h" @@ -33,32 +33,20 @@ limitations under the License. namespace tensorflow { -TFConcreteFunction::TFConcreteFunction( - const std::string& name, - std::vector captures, - FunctionMetadata metadata, ImmediateExecutionContext* ctx) - : name_(name), - captures_(std::move(captures)), - metadata_(std::move(metadata)), - ctx_(ctx) {} - -TFConcreteFunction::~TFConcreteFunction() { - Status status = ctx_->RemoveFunction(name_); - if (!status.ok()) { - LOG(ERROR) << "Failed to remove functiondef " << name_ << ". 
" - << status.error_message(); - } -} +TFConcreteFunction::TFConcreteFunction(std::unique_ptr func, + FunctionMetadata metadata) + : func_(std::move(func)), metadata_(std::move(metadata)) {} Status TFConcreteFunction::Create( const FunctionDef* function_def, std::vector captures, FunctionMetadata metadata, ImmediateExecutionContext* ctx, std::unique_ptr* out) { - TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def)); - out->reset(new TFConcreteFunction(function_def->signature().name(), - std::move(captures), std::move(metadata), - ctx)); + std::unique_ptr func; + TF_RETURN_IF_ERROR(FlatTensorFunction::Create( + function_def, std::move(captures), ctx, &func)); + + out->reset(new TFConcreteFunction(std::move(func), std::move(metadata))); return Status(); } @@ -66,30 +54,9 @@ const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const { return metadata_; } -Status TFConcreteFunction::GetCallOp( - absl::Span inputs, ImmediateOpPtr* out) { - out->reset(ctx_->CreateOperation()); - // In eager mode, TF2 python executes functions by constructing an op with - // the name of the functiondef: - // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545 - // In graph mode, we create a PartitionedCallOp instead: - // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573 - - // TODO(bmzhao): After discussing with Allen, we should execute this via a - // PartitionedCallOp for compatibility with "tooling that assumes functions in - // graphs are PartitionedCallOps". - TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr)); - - // Adding the user-provided inputs to the function. - TF_RETURN_IF_ERROR((*out)->AddInputList(inputs)); - - absl::Span captures( - reinterpret_cast(captures_.data()), - captures_.size()); - - // Adding the captures of the function. - TF_RETURN_IF_ERROR((*out)->AddInputList(captures)); - return Status(); +Status TFConcreteFunction::MakeCallOp( + absl::Span inputs, ImmediateOpPtr* out) const { + return func_->MakeCallOp(inputs, out); } } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h index d38f3546f91..edc26f4d5aa 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" @@ -58,26 +58,22 @@ class TFConcreteFunction : public ConcreteFunction { std::unique_ptr* out); // This method returns the "Call" Op used to execute the function. 
- Status GetCallOp(absl::Span inputs, - ImmediateOpPtr* out) override; + Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const override; const FunctionMetadata& GetFunctionMetadata() const override; - ~TFConcreteFunction() override; + ~TFConcreteFunction() override = default; private: - TFConcreteFunction(const std::string& name, - std::vector captures, - FunctionMetadata metadata, ImmediateExecutionContext* ctx); + TFConcreteFunction(std::unique_ptr func, + FunctionMetadata metadata); TFConcreteFunction(const TFConcreteFunction&) = delete; TFConcreteFunction& operator=(const TFConcreteFunction&) = delete; - // Name of the FunctionDef corresponding to this TFConcreteFunction - std::string name_; - std::vector captures_; + std::unique_ptr func_; FunctionMetadata metadata_; - ImmediateExecutionContext* ctx_; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc new file mode 100644 index 00000000000..ab1745dcd47 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +TFSignatureDefFunction::TFSignatureDefFunction( + std::unique_ptr func, + SignatureDefFunctionMetadata metadata) + : func_(std::move(func)), metadata_(std::move(metadata)) {} + +Status TFSignatureDefFunction::Create( + const FunctionDef* function_def, + std::vector captures, + SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx, + std::unique_ptr* out) { + std::unique_ptr func; + TF_RETURN_IF_ERROR(FlatTensorFunction::Create( + function_def, std::move(captures), ctx, &func)); + + out->reset(new TFSignatureDefFunction(std::move(func), std::move(metadata))); + return Status(); +} + +const SignatureDefFunctionMetadata& +TFSignatureDefFunction::GetFunctionMetadata() const { + return metadata_; +} + +Status TFSignatureDefFunction::MakeCallOp( + absl::Span inputs, ImmediateOpPtr* out) const { + return func_->MakeCallOp(inputs, out); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h new file mode 100644 index 00000000000..7b564185b8b --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// This is the TF eager runtime implementation of SignatureDefFunction (separate +// from the TFRT implementation). The user-facing API of SignatureDefFunctions +// and their semantic differences from ConcreteFunction are described here: +// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/cc/saved_model/experimental/public/signature_def_function.h#L30-L59 +// Additional implementation notes are available here: +// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/c/experimental/saved_model/core/signature_def_function.h#L31-L48 +class TFSignatureDefFunction : public SignatureDefFunction { + public: + // Factory function for creating a TFSignatureDefFunction. + // + // Params: + // function_def - The function_def associated with the created + // TFSignatureDefFunction. TFSignatureDefFunction will + // register this function_def with `ctx` on creation, and + // de-register it on destruction. function_def must be + // non-null, but otherwise has no lifetime requirements. + // captures - The captured TensorHandles associated with this + // TFConcreteFunction. + // metadata - FunctionMetadata associated with this TFSignatureDefFunction. + // ctx - A handle to the Tensorflow runtime. This MUST be non-null and + // outlive TFSignatureDefFunction. + // out - The output TFSignatureDefFunction. + static Status Create(const FunctionDef* function_def, + std::vector captures, + SignatureDefFunctionMetadata metadata, + ImmediateExecutionContext* ctx, + std::unique_ptr* out); + + // This method creates a "Call" Op used to execute the function. 
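Since TFSignatureDefFunction shares FlatTensorFunction with TFConcreteFunction, a loader constructs it from the same (FunctionDef, captures) pair; only the metadata type differs. A sketch using a hypothetical helper name (SignatureDefFunctionMetadata is still an empty placeholder in this patch, so it is default-constructed):

```cpp
Status LoadSignatureDefFunction(ImmediateExecutionContext* ctx,
                                const FunctionDef* fdef,
                                std::vector<ImmediateTensorHandlePtr> captures,
                                std::unique_ptr<TFSignatureDefFunction>* out) {
  // Registers `fdef` with `ctx` via FlatTensorFunction::Create underneath.
  return TFSignatureDefFunction::Create(fdef, std::move(captures),
                                        SignatureDefFunctionMetadata(), ctx,
                                        out);
}
```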
+ Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const override; + + const SignatureDefFunctionMetadata& GetFunctionMetadata() const override; + + ~TFSignatureDefFunction() override = default; + + private: + TFSignatureDefFunction(std::unique_ptr func, + SignatureDefFunctionMetadata metadata); + + TFSignatureDefFunction(const TFSignatureDefFunction&) = delete; + TFSignatureDefFunction& operator=(const TFSignatureDefFunction&) = delete; + + std::unique_ptr func_; + SignatureDefFunctionMetadata metadata_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc b/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc index d831a8dd840..a212c25bd28 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc @@ -65,10 +65,11 @@ Status Variable::ReadValue(ImmediateTensorHandlePtr* out) { Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, absl::optional name, + const char* raw_device_name, std::unique_ptr* output) { ImmediateTensorHandlePtr handle; TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable( - ctx, dtype, shape, &handle)); + ctx, dtype, shape, raw_device_name, &handle)); output->reset( new Variable(ctx, dtype, shape, std::move(name), std::move(handle))); diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/variable.h b/tensorflow/c/experimental/saved_model/core/revived_types/variable.h index 48ea1d08862..13f56fda5f3 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/variable.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/variable.h @@ -37,6 +37,7 @@ class Variable : public TensorHandleConvertible { static Status CreateUninitialized(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, absl::optional name, + const char* raw_device_name, std::unique_ptr* output); // The dtype of the underlying variable. diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_api.h b/tensorflow/c/experimental/saved_model/core/saved_model_api.h index 5d0ed63a765..ff891e13ba4 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/core/saved_model_api.h @@ -22,6 +22,7 @@ limitations under the License. 
#include #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { @@ -39,11 +40,11 @@ class SavedModelAPI { virtual Status GetFunction(const std::string& function_path, ConcreteFunction** function) = 0; - // Retrieve a function from a SavedModel, using the key of the + // Retrieve a SignatureDefFunction from a SavedModel, using the key of the // SignatureDef map: // https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89 virtual Status GetSignatureDefFunction(const std::string& signature_def_key, - ConcreteFunction** function) = 0; + SignatureDefFunction** function) = 0; virtual std::vector ListFunctions() = 0; diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc index 2037c4886de..e79fd8d7001 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/protobuf/struct.pb.h" @@ -36,52 +37,8 @@ namespace tensorflow { namespace internal { namespace { -// This returns the size of `tf.nest.flatten(value)`, on values that are -// used in tf.function's input_signatures. -int FlattenedSize(const tensorflow::StructuredValue& value, Status* status) { - // This follows the logic from - // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775 - switch (value.kind_case()) { - case StructuredValue::kDictValue: { - const DictValue& dict = value.dict_value(); - int size = 0; - for (const auto& field : dict.fields()) { - size += FlattenedSize(field.second, status); - } - return size; - } - case StructuredValue::kTupleValue: { - const TupleValue& tuple = value.tuple_value(); - int size = 0; - for (const StructuredValue& value : tuple.values()) { - size += FlattenedSize(value, status); - } - return size; - } - case StructuredValue::kListValue: { - const ListValue& list = value.list_value(); - int size = 0; - for (const StructuredValue& value : list.values()) { - size += FlattenedSize(value, status); - } - return size; - } - case StructuredValue::kTensorSpecValue: { - return 1; - } - case StructuredValue::kNoneValue: { - // Base case: do nothing. - // This arises, for example, as the top-level object of an output - // signature when there are no return values. 
- return 0; - } - default: { - status->Update(errors::Internal("Unhandled structured value kind ", - value.kind_case())); - return 0; - } - } -} +using StructuredValueDictEntry = + protobuf::MapPair; // Perform some basic sanity checks on SavedConcreteFunction's input and // output signatures with respect to the corresponding FunctionDef's input @@ -111,34 +68,34 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef( // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1974-L1979 const std::string& name = function_def->signature().name(); + const StructuredValue& input_signature = saved_concrete_function.canonicalized_input_signature(); - Status status; - int input_signature_size = FlattenedSize(input_signature, &status); - TF_RETURN_IF_ERROR(status); - if (input_signature_size + saved_concrete_function.bound_inputs_size() != + std::vector input_specs; + TF_RETURN_IF_ERROR(FlattenSignature(input_signature, &input_specs)); + if (input_specs.size() + saved_concrete_function.bound_inputs_size() != function_def->signature().input_arg_size()) { return errors::FailedPrecondition( "FunctionDef ", name, " has ", function_def->signature().input_arg_size(), - " inputs, but the SavedConcreteFunction has ", input_signature_size, + " inputs, but the SavedConcreteFunction has ", input_specs.size(), " flattened user inputs and ", saved_concrete_function.bound_inputs_size(), " captured inputs."); } const StructuredValue& output_signature = saved_concrete_function.output_signature(); - int output_signature_size = FlattenedSize(output_signature, &status); - TF_RETURN_IF_ERROR(status); - if (output_signature_size != function_def->signature().output_arg_size()) { + std::vector output_specs; + TF_RETURN_IF_ERROR(FlattenSignature(output_signature, &output_specs)); + if (output_specs.size() != function_def->signature().output_arg_size()) { return errors::FailedPrecondition( "FunctionDef ", name, " has ", function_def->signature().output_arg_size(), - " outputs, but the SavedConcreteFunction has ", output_signature_size, + " outputs, but the SavedConcreteFunction has ", output_specs.size(), " flattened outputs."); } - return status; + return Status(); } } // namespace @@ -165,9 +122,9 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx, tensorflow::TensorShape shape(variable.shape()); tensorflow::DataType dtype = variable.dtype(); - TF_RETURN_IF_ERROR( - Variable::CreateUninitialized(ctx, dtype, shape, name, output)); - + TF_RETURN_IF_ERROR(Variable::CreateUninitialized( + ctx, dtype, shape, name, + variable.device().empty() ? 
nullptr : variable.device().c_str(), output)); return Status(); } @@ -197,6 +154,62 @@ Status LoadTFConcreteFunction( out); } +Status FlattenSignature(const StructuredValue& signature, + std::vector* flattened_specs) { + // This follows the logic from + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775 + switch (signature.kind_case()) { + case StructuredValue::kDictValue: { + // Dictionaries must be sorted in order of keys + const DictValue& dict = signature.dict_value(); + std::vector entries; + entries.reserve(dict.fields_size()); + for (const auto& field : dict.fields()) { + entries.push_back(&field); + } + + std::sort(entries.begin(), entries.end(), + [](const StructuredValueDictEntry* x, + const StructuredValueDictEntry* y) { + return x->first < y->first; + }); + + for (const auto& entry : entries) { + TF_RETURN_IF_ERROR(FlattenSignature(entry->second, flattened_specs)); + } + return Status(); + } + case StructuredValue::kTupleValue: { + const TupleValue& tuple = signature.tuple_value(); + for (const StructuredValue& value : tuple.values()) { + TF_RETURN_IF_ERROR(FlattenSignature(value, flattened_specs)); + } + return Status(); + } + case StructuredValue::kListValue: { + const ListValue& list = signature.list_value(); + for (const StructuredValue& value : list.values()) { + TF_RETURN_IF_ERROR(FlattenSignature(value, flattened_specs)); + } + return Status(); + } + case StructuredValue::kTensorSpecValue: { + flattened_specs->push_back(&signature.tensor_spec_value()); + return Status(); + } + case StructuredValue::kNoneValue: { + // Base case: do nothing. + // This arises, for example, as the top-level object of an output + // signature when there are no return values. + return Status(); + } + default: { + return errors::Internal("Unhandled structured value kind ", + signature.kind_case()); + } + } +} + const SavedObject* FindNodeAtPath(StringPiece path, const SavedObjectGraph& object_graph) { const auto& nodes = object_graph.nodes(); diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h index 57f30afa91b..68bfbe32222 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" namespace tensorflow { namespace internal { @@ -59,10 +60,17 @@ Status LoadTFConcreteFunction( captured_objects, ImmediateExecutionContext* ctx, std::unique_ptr* out); -// Find the SavedObject in `object_graph` at location `path`. `path` must be a -// dot-delimited string of object names relative to the root object. If no -// object is found, returns nullptr. Callers must ensure `object_graph` outlives -// the returned pointer. +// Flattens `signature` into a vector of TensorSpecProto pointers back into +// `signature`. `signature` must outlive flattened_specs. `signature` must also +// be the input or output signature of a SavedConcreteFunction (i.e. "nested +// structures of tensorspecs"). +Status FlattenSignature(const StructuredValue& signature, + std::vector* flattened_specs); + +// Find the SavedObject in `object_graph` at location `path`. 
`path` must be +// a dot-delimited string of object names relative to the root object. If no +// object is found, returns nullptr. Callers must ensure `object_graph` +// outlives the returned pointer. const SavedObject* FindNodeAtPath(StringPiece path, const SavedObjectGraph& object_graph); diff --git a/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc b/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc index cf58e5e3536..45b0ac00c9b 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/c/tensor_interface.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -38,9 +39,15 @@ namespace { class SavedVariableLoadingTest : public ::testing::TestWithParam< std::tuple>> { public: - SavedVariableLoadingTest() - : device_mgr_(testing::CreateTestingDeviceMgr()), - ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {} + SavedVariableLoadingTest() { + SessionOptions options; + options.config.mutable_device_count()->insert({"CPU", 3}); + std::vector> devices; + TF_CHECK_OK(DeviceFactory::AddDevices( + options, "/job:localhost/replica:0/task:0", &devices)); + device_mgr_ = absl::make_unique(std::move(devices)); + ctx_ = testing::CreateTestingEagerContext(device_mgr_.get()); + } EagerContext* context() { return ctx_.get(); } @@ -67,6 +74,39 @@ TEST_P(SavedVariableLoadingTest, LoadSavedVariableSuccessful) { EXPECT_EQ(var->shape(), shape); } +// Verify that a device specified in the SavedVariable is kept. +TEST_P(SavedVariableLoadingTest, LoadSavedVariableWithDevice) { + auto& test_params = GetParam(); + DataType dtype = std::get<0>(test_params); + TensorShape shape(std::get<1>(test_params)); + + SavedVariable saved_variable; + saved_variable.set_dtype(dtype); + saved_variable.set_device("/job:localhost/replica:0/task:0/device:CPU:1"), + shape.AsProto(saved_variable.mutable_shape()); + + std::unique_ptr var; + TF_ASSERT_OK(internal::LoadSavedVariable(context(), saved_variable, &var)); + EXPECT_EQ(down_cast(var->handle())->resource_device()->name(), + "/job:localhost/replica:0/task:0/device:CPU:1"); +} + +// Verify load failure if a non-existing device is specified. +TEST_P(SavedVariableLoadingTest, LoadSavedVariableWithInvalidDevice) { + auto& test_params = GetParam(); + DataType dtype = std::get<0>(test_params); + TensorShape shape(std::get<1>(test_params)); + + SavedVariable saved_variable; + saved_variable.set_dtype(dtype); + saved_variable.set_device("/job:localhost/replica:0/task:0/device:CPU:99"), + shape.AsProto(saved_variable.mutable_shape()); + + std::unique_ptr var; + ASSERT_NE(Status::OK(), + internal::LoadSavedVariable(context(), saved_variable, &var)); +} + // Assigning and reading values should yield // consistent results. 
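The two device tests above exercise the new plumbing through LoadSavedVariable; at the Variable layer the same pinning can be requested directly. A sketch with a hypothetical helper name and an illustrative device string (assuming a multi-CPU context like the fixture above); passing nullptr instead leaves placement to the runtime, which matches the previous behaviour:

```cpp
Status CreatePinnedVariable(ImmediateExecutionContext* ctx,
                            std::unique_ptr<Variable>* var) {
  TensorShape shape;  // scalar
  return Variable::CreateUninitialized(
      ctx, DT_FLOAT, shape, /*name=*/absl::nullopt,
      "/job:localhost/replica:0/task:0/device:CPU:1", var);
}
```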
TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) { @@ -79,7 +119,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) { Status status; std::unique_ptr var; TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape, - absl::nullopt, &var)); + absl::nullopt, nullptr, &var)); // Create a TensorHandle ImmediateTensorHandlePtr expected_handle = diff --git a/tensorflow/c/experimental/saved_model/core/signature_def_function.h b/tensorflow/c/experimental/saved_model/core/signature_def_function.h new file mode 100644 index 00000000000..0a217f3cc21 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/signature_def_function.h @@ -0,0 +1,62 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_ + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +namespace tensorflow { + +// See tensorflow/cc/experimental/saved_model/public/signature_def_function.h +// for SignatureDefFunction's intended user-facing semantics. +// This class is the "implementation" C++ part of the C++/C/C++ sandwich for +// a SignatureDefFunction. +// Note(bmzhao): Implementation-wise, SignatureDefFunctions are always saved as +// a "BareConcreteFunction", w/o a FunctionSpec, rather than a SavedFunction: +// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/core/protobuf/saved_object_graph.proto#L60 +// Additionally they are guaranteed to be children of the .signatures attribute +// of the root object, where the child object "name" is the signature_def key: +// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/python/saved_model/signature_serialization.py#L181-L230 +// One of the critical requirements of SignatureDef functions is that their +// inputs and outputs are "named". For example, a `.signatures` function: +// a. Requires users to pass: kwargs of all inputs: +// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/python/saved_model/signature_serialization.py#L119-L126 +// b. Returns a dictionary of named outputs. 
+// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/python/saved_model/signature_serialization.py#L153-L161 +// Since SignatureDefFunctions do not have FunctionSpecs, but guarantee the +// dictionary of inputs/outputs, we can parse these dictionaries' keys to obtain +// the input/output names of the SignatureDef: +// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/core/protobuf/meta_graph.proto#L318-L321 +class SignatureDefFunction { + public: + virtual ~SignatureDefFunction() = default; + + // Creates a "Call" Op used to execute the function. + virtual Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const = 0; + + virtual const SignatureDefFunctionMetadata& GetFunctionMetadata() const = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h new file mode 100644 index 00000000000..5a579676d4e --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ + +namespace tensorflow { + +class SignatureDefFunctionMetadata { + // TODO(bmzhao): Fill in with fields as necessary +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ diff --git a/tensorflow/c/experimental/saved_model/core/signature_flattening_test.cc b/tensorflow/c/experimental/saved_model/core/signature_flattening_test.cc new file mode 100644 index 00000000000..9ee495f524a --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/signature_flattening_test.cc @@ -0,0 +1,133 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h" +#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { +namespace { + +// Validates names, shapes, and dtypes of two tensorspecprotos are equivalent. +bool TensorSpecsAreEqual(const TensorSpecProto& spec, + const std::string& expected_name, + const PartialTensorShape& expected_shape, + DataType expected_dtype) { + return spec.name() == expected_name && + PartialTensorShape(spec.shape()).IsIdenticalTo(expected_shape) && + spec.dtype() == expected_dtype; +} + +// This tests the common case for a tf.function w/o inputs. This ends up +// being serialized as a tuple of an empty tuple + empty dictionary +// (corresponding to the args, kwargs) of the function. +TEST(SignatureFlatteningTest, ZeroArgInputSignature) { + std::vector flattened; + StructuredValue value = testing::ZeroArgInputSignature(); + TF_EXPECT_OK(internal::FlattenSignature(value, &flattened)); + EXPECT_EQ(flattened.size(), 0); +} + +// This tests the common case for a tf.function w/o outputs. This ends up +// being serialized as a "NoneValue". +TEST(SignatureFlatteningTest, ZeroRetOutputSignature) { + std::vector flattened; + StructuredValue value = testing::ZeroReturnOutputSignature(); + TF_EXPECT_OK(internal::FlattenSignature(value, &flattened)); + EXPECT_EQ(flattened.size(), 0); +} + +TEST(SignatureFlatteningTest, SingleArgInputSignature) { + std::vector flattened; + StructuredValue value = testing::SingleArgInputSignature(); + TF_EXPECT_OK(internal::FlattenSignature(value, &flattened)); + EXPECT_EQ(flattened.size(), 1); + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0], + /* expected_name = */ "x", + /* expected_shape = */ {1, 10}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[0]->DebugString(); +} + +TEST(SignatureFlatteningTest, SingleReturnOutputSignature) { + std::vector flattened; + StructuredValue value = testing::SingleReturnOutputSignature(); + TF_EXPECT_OK(internal::FlattenSignature(value, &flattened)); + EXPECT_EQ(flattened.size(), 1); + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0], + /* expected_name = */ "", + /* expected_shape = */ {1}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[0]->DebugString(); +} + +TEST(SignatureFlatteningTest, ThreeArgInputSignature) { + std::vector flattened; + StructuredValue value = testing::ThreeArgInputSignature(); + TF_EXPECT_OK(internal::FlattenSignature(value, &flattened)); + EXPECT_EQ(flattened.size(), 3); + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0], + /* expected_name = */ "x", + /* expected_shape = */ {1}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[0]->DebugString(); + + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[1], + /* expected_name = */ "y", + /* expected_shape = */ {1}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[1]->DebugString(); + + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[2], + /* expected_name = */ "z", + /* expected_shape = */ {1}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[2]->DebugString(); +} + +// This test has an exotic outputsignature of tuple of a +// dictionary, tensor 
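The test that follows covers a nested tuple/dict output signature; the key-ordering rule it relies on can also be seen in isolation with a hand-built dict. This extra check is a sketch, not part of the patch (proto accessors follow struct.proto):

```cpp
TEST(SignatureFlatteningTest, DictKeysAreSorted) {
  StructuredValue sig;
  auto* fields = sig.mutable_dict_value()->mutable_fields();
  // Insert out of order; FlattenSignature sorts dict entries by key.
  (*fields)["b"].mutable_tensor_spec_value()->set_name("b");
  (*fields)["a"].mutable_tensor_spec_value()->set_name("a");

  std::vector<const TensorSpecProto*> specs;
  TF_EXPECT_OK(internal::FlattenSignature(sig, &specs));
  ASSERT_EQ(specs.size(), 2);
  EXPECT_EQ(specs[0]->name(), "a");
  EXPECT_EQ(specs[1]->name(), "b");
}
```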
+TEST(SignatureFlatteningTest, ThreeReturnOutputSignature) { + std::vector flattened; + StructuredValue value = testing::ThreeReturnOutputSignature(); + TF_EXPECT_OK(internal::FlattenSignature(value, &flattened)); + EXPECT_EQ(flattened.size(), 3); + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0], + /* expected_name = */ "0/a", + /* expected_shape = */ {1}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[0]->DebugString(); + + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[1], + /* expected_name = */ "0/b", + /* expected_shape = */ {1}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[1]->DebugString(); + + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[2], + /* expected_name = */ "1", + /* expected_shape = */ {1}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[2]->DebugString(); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/test_utils.cc b/tensorflow/c/experimental/saved_model/core/test_utils.cc index b803d129b90..7c11158b17d 100644 --- a/tensorflow/c/experimental/saved_model/core/test_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/test_utils.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -45,7 +45,6 @@ EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr) { return EagerContextPtr(new EagerContext( SessionOptions(), tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, - tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /* async= */ false, /* lazy_copy_function_remote_inputs= */ false, device_mgr, /* device_mgr_owned= */ false, /* rendezvous= */ nullptr, diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc index c22f8d86174..ab7052b52ed 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" #include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -47,6 +48,7 @@ limitations under the License. 
#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/stringpiece.h" @@ -241,8 +243,11 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle, // TODO(bmzhao): This requires using the newly added Save/Restore // functions from // https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c - return errors::Unimplemented( - "Restoring non-variable objects has not been implemented yet. "); + LOG(WARNING) << "Restoring non-variable objects has not been " + "implemented yet. (Kind=" + << bundle->saved_object_graph().nodes(node).kind_case() + << ")"; + return Status::OK(); } Variable* variable = @@ -301,7 +306,7 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path, } Status TFSavedModelAPI::GetSignatureDefFunction( - const std::string& signature_def_key, ConcreteFunction** function) { + const std::string& signature_def_key, SignatureDefFunction** function) { // TODO(bmzhao): Add support for retrieving a signaturedef function. return errors::Unimplemented( "Retrieving SignatureDef functions is unimplemented currently"); diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h index fc8e738e86f..fd07c09474b 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" #include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/core/platform/status.h" @@ -55,7 +56,7 @@ class TFSavedModelAPI : public SavedModelAPI { ConcreteFunction** function) override; Status GetSignatureDefFunction(const std::string& signature_def_key, - ConcreteFunction** function) override; + SignatureDefFunction** function) override; static Status Load( const std::string& directory, diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 323298c5fc1..c0d121a4aee 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -142,6 +142,8 @@ cc_library( ":concrete_function_list_type", ":concrete_function_type", ":saved_model_api_type", + ":signature_def_function", + ":signature_def_function_type", "//tensorflow/c:c_api_macros", "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_internal", @@ -165,6 +167,77 @@ cc_library( ], ) +cc_library( + name = "signature_def_function", + srcs = [ + "signature_def_function.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:signature_def_function.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":signature_def_function_metadata", + ":signature_def_function_metadata_type", + ":signature_def_function_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_status_internal", + 
"//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:tfe_op_internal", + "//tensorflow/c/eager:tfe_tensorhandle_internal", + "//tensorflow/c/experimental/saved_model/core:signature_def_function", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "signature_def_function_type", + hdrs = [ + "signature_def_function_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c/experimental/saved_model/core:signature_def_function", + ], +) + +cc_library( + name = "signature_def_function_metadata", + srcs = [ + "signature_def_function_metadata.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":signature_def_function_metadata_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + ], +) + +cc_library( + name = "signature_def_function_metadata_type", + hdrs = [ + "signature_def_function_metadata_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + ], +) + tf_cc_test( name = "saved_model_api_test", size = "small", diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc index 65c6eca5623..2beed8f4119 100644 --- a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc @@ -34,15 +34,15 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) { &tensorflow::unwrap(func)->GetFunctionMetadata())); } -TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func, - TFE_TensorHandle** inputs, int num_inputs, - TF_Status* status) { +TFE_Op* TF_ConcreteFunctionMakeCallOp(TF_ConcreteFunction* func, + TFE_TensorHandle** inputs, int num_inputs, + TF_Status* status) { tensorflow::ImmediateOpPtr call_op; absl::Span input_span( reinterpret_cast( tensorflow::unwrap(inputs)), static_cast(num_inputs)); - status->status = tensorflow::unwrap(func)->GetCallOp(input_span, &call_op); + status->status = tensorflow::unwrap(func)->MakeCallOp(input_span, &call_op); if (!status->status.ok()) { return nullptr; } diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc index 983c98affb2..b89fb9f6d64 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" #include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h" +#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_internal.h" #include "tensorflow/core/common_runtime/eager/context.h" @@ -106,9 +107,11 @@ TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model, return tensorflow::wrap(result); } -TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( - TF_SavedModel* model, const char* signature_def_key, TF_Status* status) { - tensorflow::ConcreteFunction* result = nullptr; +TF_CAPI_EXPORT extern TF_SignatureDefFunction* +TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model, + const char* signature_def_key, + TF_Status* status) { + tensorflow::SignatureDefFunction* result = nullptr; tensorflow::Status get_function_status = tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key, &result); diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc index e58b232f9c9..df998fcf6cd 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc @@ -107,7 +107,7 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) { compute_fn_inputs.push_back(input_a); compute_fn_inputs.push_back(input_b); - TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp( + TFE_Op* compute_fn_op = TF_ConcreteFunctionMakeCallOp( compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status); EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_function.cc b/tensorflow/c/experimental/saved_model/internal/signature_def_function.cc new file mode 100644 index 00000000000..64f7506f32e --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_function.cc @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" +#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h" +#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h" +#include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/core/platform/status.h" + +extern "C" { + +TF_SignatureDefFunctionMetadata* TF_SignatureDefFunctionGetMetadata( + TF_SignatureDefFunction* func) { + return tensorflow::wrap(const_cast( + &tensorflow::unwrap(func)->GetFunctionMetadata())); +} + +TFE_Op* TF_SignatureDefFunctionMakeCallOp(TF_SignatureDefFunction* func, + TFE_TensorHandle** inputs, + int num_inputs, TF_Status* status) { + tensorflow::ImmediateOpPtr call_op; + absl::Span input_span( + reinterpret_cast( + tensorflow::unwrap(inputs)), + static_cast(num_inputs)); + status->status = tensorflow::unwrap(func)->MakeCallOp(input_span, &call_op); + if (!status->status.ok()) { + return nullptr; + } + return tensorflow::wrap(call_op.release()); +} + +} // end extern "C" diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/register_all_passes.cc b/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata.cc similarity index 73% rename from tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/register_all_passes.cc rename to tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata.cc index 9349bee041e..c5c3616211c 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/register_all_passes.cc +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata.cc @@ -13,16 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" -namespace mlir { +#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h" -namespace { - -bool register_all_passes = ([] { - mhlo::registerAllMhloPasses(); - lmhlo::registerAllLmhloPasses(); -}(), true); - -} // namespace -} // namespace mlir +// TODO(bmzhao): Add getter functions here as necessary. diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h b/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h new file mode 100644 index 00000000000..fa6d0f6541e --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefFunctionMetadata, + TF_SignatureDefFunctionMetadata) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h b/tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h new file mode 100644 index 00000000000..ca44dc43bd6 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" + +typedef struct TF_SignatureDefFunction TF_SignatureDefFunction; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefFunction, + TF_SignatureDefFunction) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/public/BUILD b/tensorflow/c/experimental/saved_model/public/BUILD index af65e05e7f6..d29585ae1ba 100644 --- a/tensorflow/c/experimental/saved_model/public/BUILD +++ b/tensorflow/c/experimental/saved_model/public/BUILD @@ -24,6 +24,8 @@ exports_files( "concrete_function_list.h", "function_metadata.h", "saved_model_api.h", + "signature_def_function.h", + "signature_def_function_metadata.h", ], visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"], ) @@ -39,6 +41,8 @@ cc_library( ":concrete_function_list", ":function_metadata", ":saved_model_api", + ":signature_def_function", + ":signature_def_function_metadata", ], ) @@ -61,3 +65,13 @@ alias( name = "saved_model_api", actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api", ) + +alias( + name = "signature_def_function", + actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function", +) + +alias( + name = "signature_def_function_metadata", + actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function_metadata", +) diff --git a/tensorflow/c/experimental/saved_model/public/README.md b/tensorflow/c/experimental/saved_model/public/README.md new file mode 100644 index 00000000000..9b3f392d7a8 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/README.md @@ -0,0 +1,28 @@ +# TensorFlow Saved Model C API + +## Small ConcreteFunction Example + +The following example loads a saved model from `"/path/to/model"` and +executes a function `f` taking no arguments and returning one single +value (error checking is omitted for simplicity): + +```c +TF_Status* status = TF_NewStatus(); +TFE_ContextOptions* ctx_options = TFE_NewContextOptions(); +TFE_Context* ctx = TFE_NewContext(ctx_options, status); + +TF_SavedModel* saved_model = TF_LoadSavedModel("/path/to/model", ctx, status); +TF_ConcreteFunction* f = TF_GetSavedModelConcreteFunction(saved_model, "f", status); +TFE_Op* op = TF_ConcreteFunctionMakeCallOp(f, NULL, 0, status); + +TFE_TensorHandle* output; +int nouts = 1; +TFE_Execute(op, &output, &nouts, status); + +TFE_DeleteTensorHandle(output); +TFE_DeleteOp(op); +TFE_DeleteSavedModel(saved_model); +TFE_DeleteContext(ctx); +TFE_DeleteContextOptions(ctx_options); +TF_DeleteStatus(status); +``` diff --git a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h index 30f533f140a..cedb9de66b8 100644 --- a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h @@ -21,6 +21,8 @@ limitations under the License. 
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" #include "tensorflow/c/experimental/saved_model/public/function_metadata.h" #include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" // IWYU pragma: end_exports #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function.h b/tensorflow/c/experimental/saved_model/public/concrete_function.h index ee5292294d6..ff8a245961a 100644 --- a/tensorflow/c/experimental/saved_model/public/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/public/concrete_function.h @@ -40,7 +40,14 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata( // The caller is responsible for deleting the returned TFE_Op. If op // construction fails, `status` will be non-OK and the returned pointer will be // null. -TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp( +// TODO(bmzhao): Remove this function in a subsequent change; Design + implement +// a Function Execution interface for ConcreteFunction that accepts a tagged +// union of types (tensorflow::Value). This effectively requires moving much of +// the implementation of function.py/def_function.py to C++, and exposing a +// high-level API here. A strawman for what this interface could look like: +// TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value* +// inputs, int num_inputs, TF_Status* status); +TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionMakeCallOp( TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs, TF_Status* status); diff --git a/tensorflow/c/experimental/saved_model/public/saved_model_api.h b/tensorflow/c/experimental/saved_model/public/saved_model_api.h index 875167bec63..80ba37bab26 100644 --- a/tensorflow/c/experimental/saved_model/public/saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/public/saved_model_api.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" #include "tensorflow/c/tf_status.h" #ifdef __cplusplus @@ -91,10 +92,13 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction( // status - Set to OK on success and an appropriate error on failure. // Returns: // If status is not OK, returns nullptr. Otherwise, returns a -// TF_ConcreteFunction instance. Once `model` is deleted, all -// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted. -TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( - TF_SavedModel* model, const char* signature_def_key, TF_Status* status); +// TF_SignatureDefFunction instance. Once `model` is deleted, all +// `TF_SignatureDefFunctions` retrieved from it are invalid, and have been +// deleted. +TF_CAPI_EXPORT extern TF_SignatureDefFunction* +TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model, + const char* signature_def_key, + TF_Status* status); // Returns a list of all ConcreteFunctions stored in this SavedModel. // The lifetime of the returned list is bound to `model`. 
diff --git a/tensorflow/c/experimental/saved_model/public/signature_def_function.h b/tensorflow/c/experimental/saved_model/public/signature_def_function.h new file mode 100644 index 00000000000..16471fdc1fa --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/signature_def_function.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that corresponds to a SignatureDefFunction loaded from a +// SavedModel. +typedef struct TF_SignatureDefFunction TF_SignatureDefFunction; + +// Returns FunctionMetadata associated with `func`. Metadata's lifetime is +// bound to `func`, which is bound to the TF_SavedModel it was loaded from. +TF_CAPI_EXPORT extern TF_SignatureDefFunctionMetadata* +TF_SignatureDefFunctionGetMetadata(TF_SignatureDefFunction* func); + +// Returns a TFE_Op suitable for executing this function. Caller must provide +// all function inputs in `inputs`, and must not add any additional inputs on +// the returned op. (i.e. don't call TFE_OpAddInput or TFE_OpAddInputList). +// The caller is responsible for deleting the returned TFE_Op. If op +// construction fails, `status` will be non-OK and the returned pointer will be +// null. +TF_CAPI_EXPORT extern TFE_Op* TF_SignatureDefFunctionMakeCallOp( + TF_SignatureDefFunction* func, TFE_TensorHandle** inputs, int num_inputs, + TF_Status* status); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h b/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h new file mode 100644 index 00000000000..6f4459732c4 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that corresponds to a SignatureDefFunction loaded from a +// SavedModel. +typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata; + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD new file mode 100644 index 00000000000..7daa311d461 --- /dev/null +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -0,0 +1,60 @@ +# Description: +# StreamExecutor C API. + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "stream_executor", + srcs = ["stream_executor.cc"], + hdrs = ["stream_executor.h"], + visibility = ["//visibility:public"], + deps = [ + ":stream_executor_internal", + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_helper", + "//tensorflow/core:lib", + "//tensorflow/stream_executor:executor_cache", + "//tensorflow/stream_executor:multi_platform_manager", + "//tensorflow/stream_executor:platform", + "//tensorflow/stream_executor:stream_executor_internal", + "//tensorflow/stream_executor:stream_executor_pimpl", + "//tensorflow/stream_executor:timer", + ], +) + +cc_library( + name = "stream_executor_internal", + hdrs = [ + "stream_executor.h", + "stream_executor_internal.h", + ], + deps = [ + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_status", + "//tensorflow/stream_executor:executor_cache", + "//tensorflow/stream_executor/lib", + ], +) + +tf_cc_test( + name = "stream_executor_test", + srcs = ["stream_executor_test.cc"], + deps = [ + ":stream_executor", + ":stream_executor_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/protobuf:error_codes_proto_impl_cc", + "//tensorflow/stream_executor:multi_platform_manager", + "//tensorflow/stream_executor:stream", + "//tensorflow/stream_executor:stream_executor_pimpl", + ], +) diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc new file mode 100644 index 00000000000..901ef942305 --- /dev/null +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -0,0 +1,853 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file extends/implements core stream executor base classes in terms of +// the C API defined in stream_executor.h. 
A class "CSomething" represents a +// "Something" that can be manipulated via calls in the C interface and a C +// struct called "SP_Something". +// +// This file also contains stream_executor::Platform registration for pluggable +// device. +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" + +#include + +#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/executor_cache.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/stream_executor/platform.h" +#include "tensorflow/stream_executor/stream_executor_internal.h" +#include "tensorflow/stream_executor/stream_executor_pimpl.h" +#include "tensorflow/stream_executor/timer.h" + +using tensorflow::StatusFromTF_Status; + +namespace stream_executor { +namespace { + +#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \ + do { \ + if (STRUCT_OBJ.struct_size == 0) { \ + return port::FailedPreconditionError( \ + "struct_size field in " #STRUCT_NAME \ + " must be set to " #SIZE_VALUE_NAME "."); \ + } \ + } while (0) + +#define VALIDATE_MEMBER(STRUCT_NAME, STRUCT_OBJ, NAME) \ + do { \ + if (STRUCT_OBJ.NAME == 0) { \ + return port::FailedPreconditionError( \ + "'" #NAME "' field in " #STRUCT_NAME " must be set."); \ + } \ + } while (0) + +port::Status ValidateSPPlatform(const SP_Platform& platform) { + VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE); + VALIDATE_MEMBER(SP_Platform, platform, name); + VALIDATE_MEMBER(SP_Platform, platform, type); + // `visible_device_count` could be 0 at initialization time. + return port::Status::OK(); +} + +port::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) { + VALIDATE_STRUCT_SIZE(SP_PlatformFns, platform_fns, + SP_PLATFORM_FNS_STRUCT_SIZE); + VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_device); + VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_device); + VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_stream_executor); + VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_stream_executor); + VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_timer_fns); + VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_timer_fns); + return port::Status::OK(); +} + +port::Status ValidateSPTimerFns(const SP_TimerFns& timer_fns) { + VALIDATE_STRUCT_SIZE(SP_TimerFns, timer_fns, SP_TIMER_FNS_STRUCT_SIZE); + VALIDATE_MEMBER(SP_TimerFns, timer_fns, nanoseconds); + return port::Status::OK(); +} + +port::Status ValidateSPAllocatorStats(const SP_AllocatorStats& stats) { + VALIDATE_STRUCT_SIZE(SP_AllocatorStats, stats, SP_ALLOCATORSTATS_STRUCT_SIZE); + // All other fields could theoretically be zero/null. + return port::Status::OK(); +} + +port::Status ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase& mem) { + VALIDATE_STRUCT_SIZE(SP_DeviceMemoryBase, mem, + SP_DEVICE_MEMORY_BASE_STRUCT_SIZE); + // All other fields could theoretically be zero/null. + return port::Status::OK(); +} + +port::Status ValidateSPDevice(const SP_Device& device) { + VALIDATE_STRUCT_SIZE(SP_Device, device, SP_DEVICE_STRUCT_SIZE); + // All other fields could theoretically be zero/null. 
+ return port::Status::OK(); +} + +port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se, + const SP_Platform& platform) { + VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se, SP_STREAM_EXECUTOR_STRUCT_SIZE); + VALIDATE_MEMBER(SP_StreamExecutor, se, allocate); + VALIDATE_MEMBER(SP_StreamExecutor, se, deallocate); + VALIDATE_MEMBER(SP_StreamExecutor, se, get_allocator_stats); + VALIDATE_MEMBER(SP_StreamExecutor, se, host_memory_allocate); + VALIDATE_MEMBER(SP_StreamExecutor, se, host_memory_deallocate); + if (platform.supports_unified_memory) { + VALIDATE_MEMBER(SP_StreamExecutor, se, unified_memory_allocate); + VALIDATE_MEMBER(SP_StreamExecutor, se, unified_memory_deallocate); + } + VALIDATE_MEMBER(SP_StreamExecutor, se, device_memory_usage); + VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream); + VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_stream); + VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream_dependency); + VALIDATE_MEMBER(SP_StreamExecutor, se, get_stream_status); + VALIDATE_MEMBER(SP_StreamExecutor, se, create_event); + VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_event); + VALIDATE_MEMBER(SP_StreamExecutor, se, get_event_status); + VALIDATE_MEMBER(SP_StreamExecutor, se, record_event); + VALIDATE_MEMBER(SP_StreamExecutor, se, wait_for_event); + VALIDATE_MEMBER(SP_StreamExecutor, se, create_timer); + VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_timer); + VALIDATE_MEMBER(SP_StreamExecutor, se, start_timer); + VALIDATE_MEMBER(SP_StreamExecutor, se, stop_timer); + VALIDATE_MEMBER(SP_StreamExecutor, se, memcpy_dtoh); + VALIDATE_MEMBER(SP_StreamExecutor, se, memcpy_htod); + VALIDATE_MEMBER(SP_StreamExecutor, se, sync_memcpy_dtoh); + VALIDATE_MEMBER(SP_StreamExecutor, se, sync_memcpy_htod); + VALIDATE_MEMBER(SP_StreamExecutor, se, block_host_for_event); + VALIDATE_MEMBER(SP_StreamExecutor, se, synchronize_all_activity); + VALIDATE_MEMBER(SP_StreamExecutor, se, host_callback); + return port::Status::OK(); +} + +port::Status ValidateSEPlatformRegistrationParams( + const SE_PlatformRegistrationParams& params) { + VALIDATE_STRUCT_SIZE(SE_PlatformRegistrationParams, params, + SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE); + VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform); + VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform_fns); + return port::Status::OK(); +} +#undef VALIDATE_MEMBER + +struct TFStatusDeleter { + void operator()(TF_Status* s) const { TF_DeleteStatus(s); } +}; +using OwnedTFStatus = std::unique_ptr; + +class CStream : public internal::StreamInterface { + public: + CStream(SP_Device* device, SP_StreamExecutor* stream_executor) + : device_(device), + stream_executor_(stream_executor), + stream_handle_(nullptr) {} + ~CStream() override { Destroy(); } + + port::Status Create() { + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->create_stream(device_, &stream_handle_, c_status.get()); + port::Status s = StatusFromTF_Status(c_status.get()); + return s; + } + + void Destroy() { + if (stream_handle_ != nullptr) { + stream_executor_->destroy_stream(device_, stream_handle_); + stream_handle_ = nullptr; + } + } + + SP_Stream Handle() { return stream_handle_; } + + private: + SP_Device* device_; + SP_StreamExecutor* stream_executor_; + SP_Stream stream_handle_; +}; + +// Converts SE_EventStatus to Event::Status. 
+Event::Status SEEventStatusToEventStatus(SE_EventStatus s) { + switch (s) { + case SE_EVENT_ERROR: + return Event::Status::kError; + case SE_EVENT_PENDING: + return Event::Status::kPending; + case SE_EVENT_COMPLETE: + return Event::Status::kComplete; + default: + return Event::Status::kUnknown; + } +} + +class CEvent : public internal::EventInterface { + public: + CEvent(SP_Device* device, SP_StreamExecutor* stream_executor) + : device_(device), + stream_executor_(stream_executor), + event_handle_(nullptr) {} + ~CEvent() override { Destroy(); } + + port::Status Create() { + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->create_event(device_, &event_handle_, c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + + port::Status Record(SP_Stream stream_handle) { + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->record_event(device_, stream_handle, event_handle_, + c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + + void Destroy() { + if (event_handle_ != nullptr) { + stream_executor_->destroy_event(device_, event_handle_); + event_handle_ = nullptr; + } + } + + SP_Event Handle() { return event_handle_; } + + private: + SP_Device* device_; + SP_StreamExecutor* stream_executor_; + SP_Event event_handle_; +}; + +class CTimer : public internal::TimerInterface { + public: + CTimer(SP_Device* device, SP_StreamExecutor* stream_executor, + SP_TimerFns* timer_fns) + : device_(device), + stream_executor_(stream_executor), + timer_handle_(nullptr), + timer_fns_(timer_fns) {} + ~CTimer() override { Destroy(); } + + port::Status Create() { + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->create_timer(device_, &timer_handle_, c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + + void Destroy() { + if (timer_handle_ != nullptr) { + stream_executor_->destroy_timer(device_, timer_handle_); + timer_handle_ = nullptr; + } + } + + SP_Timer Handle() { return timer_handle_; } + + uint64 Microseconds() const override { + return timer_fns_->nanoseconds(timer_handle_) / 1000; + } + + uint64 Nanoseconds() const override { + return timer_fns_->nanoseconds(timer_handle_); + } + + private: + SP_Device* device_; + SP_StreamExecutor* stream_executor_; + SP_Timer timer_handle_; + SP_TimerFns* timer_fns_; +}; + +// Converts DeviceMemoryBase to a C struct. +SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) { + SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE}; + // `opaque` field inside SP_DeviceMemoryBase is not const. + // Therefore, we need to cast away the constness before setting it. + device_memory_base.opaque = const_cast(mem->opaque()); + device_memory_base.size = mem->size(); + device_memory_base.payload = mem->payload(); + // TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here. + return device_memory_base; +} + +DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) { + DeviceMemoryBase base(mem.opaque, mem.size); + base.SetPayload(mem.payload); + // TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here. + return base; +} + +// Wrapper that allows passing std::function across C API. +struct HostCallbackContext { + std::function callback; +}; + +// This wrapper allows calling `HostCallbackContext::callback` across C API. +// This function matches `SE_StatusCallbackFn` signature and will be passed as +// `callback_fn` to `host_callback` in `SP_StreamExecutor`. 
+void HostCallbackTrampoline(void* ctx, TF_Status* status) { + HostCallbackContext* host_ctx = static_cast(ctx); + port::Status s = host_ctx->callback(); + Set_TF_Status_from_Status(status, s); + delete host_ctx; +} + +class CStreamExecutor : public internal::StreamExecutorInterface { + public: + explicit CStreamExecutor(SP_Device device, SP_StreamExecutor* stream_executor, + SP_Platform* platform, SP_PlatformFns* platform_fns, + SP_TimerFns* timer_fns, const std::string& name, + int visible_device_count) + : device_(std::move(device)), + stream_executor_(stream_executor), + platform_(platform), + platform_fns_(platform_fns), + timer_fns_(timer_fns), + platform_name_(name), + visible_device_count_(visible_device_count) {} + + ~CStreamExecutor() override { + platform_fns_->destroy_device(platform_, &device_); + } + + port::Status Init(int device_ordinal, DeviceOptions device_options) override { + return port::Status::OK(); + } + + DeviceMemoryBase Allocate(uint64 size, int64 memory_space) override { + SP_DeviceMemoryBase mem = {SP_DEVICE_MEMORY_BASE_STRUCT_SIZE}; + stream_executor_->allocate(&device_, size, memory_space, &mem); + port::Status status = ValidateSPDeviceMemoryBase(mem); + if (!status.ok()) { + LOG(ERROR) << status.error_message(); + } + return DeviceMemoryBaseFromC(mem); + } + DeviceMemoryBase Allocate(uint64 size) { + return Allocate(size, /*memory_space=*/0); + } + void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset, + uint64 size) override { + LOG(FATAL) << "GetSubBuffer is not supported by pluggable device."; + } + + void Deallocate(DeviceMemoryBase* mem) override { + SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(mem); + stream_executor_->deallocate(&device_, &device_memory_base); + } + + void* HostMemoryAllocate(uint64 size) override { + return stream_executor_->host_memory_allocate(&device_, size); + } + + void HostMemoryDeallocate(void* mem) override { + stream_executor_->host_memory_deallocate(&device_, mem); + } + + bool HostMemoryRegister(void* mem, uint64 size) override { return false; } + bool HostMemoryUnregister(void* mem) override { return false; } + + void* UnifiedMemoryAllocate(uint64 size) override { + CHECK(stream_executor_->unified_memory_allocate); + return stream_executor_->unified_memory_allocate(&device_, size); + } + + void UnifiedMemoryDeallocate(void* mem) override { + CHECK(stream_executor_->unified_memory_deallocate); + stream_executor_->unified_memory_deallocate(&device_, mem); + } + + absl::optional GetAllocatorStats() override { + SP_AllocatorStats c_stats{SP_ALLOCATORSTATS_STRUCT_SIZE}; + TF_Bool has_stats = + stream_executor_->get_allocator_stats(&device_, &c_stats); + if (!has_stats) { + return absl::nullopt; + } + port::Status status = ValidateSPAllocatorStats(c_stats); + if (!status.ok()) { + LOG(ERROR) << status.error_message(); + return absl::nullopt; + } + // TODO(annarev): validate SP_AllocatorStats. 
+ ::stream_executor::AllocatorStats stats; + stats.num_allocs = c_stats.num_allocs; + stats.bytes_in_use = c_stats.bytes_in_use; + stats.peak_bytes_in_use = c_stats.peak_bytes_in_use; + stats.largest_alloc_size = c_stats.largest_alloc_size; + if (c_stats.has_bytes_limit) { + stats.bytes_limit = c_stats.bytes_limit; + } + stats.bytes_reserved = c_stats.bytes_reserved; + stats.peak_bytes_reserved = c_stats.peak_bytes_reserved; + if (c_stats.has_bytes_reservable_limit) { + stats.bytes_reservable_limit = c_stats.bytes_reservable_limit; + } + stats.largest_free_block_bytes = c_stats.largest_free_block_bytes; + return stats; + } + bool SynchronizeAllActivity() override { + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->synchronize_all_activity(&device_, c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return false; + } + return true; + } + port::Status SynchronousMemZero(DeviceMemoryBase* location, + uint64 size) override { + // TODO(annarev): figure out if we should support memzero/memset + // functionality by allocating on host and then copying to device. + return port::UnimplementedError( + "SynchronousMemZero is not supported by pluggable device."); + } + port::Status SynchronousMemSet(DeviceMemoryBase* location, int value, + uint64 size) override { + return port::UnimplementedError( + "SynchronousMemSet is not supported by pluggable device."); + } + port::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst, + const void* host_src, uint64 size) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(gpu_dst); + stream_executor_->sync_memcpy_htod(&device_, &device_memory_base, host_src, + size, c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + port::Status SynchronousMemcpy(void* host_dst, + const DeviceMemoryBase& gpu_src, + uint64 size) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(&gpu_src); + stream_executor_->sync_memcpy_dtoh(&device_, host_dst, &device_memory_base, + size, c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, + uint64 size) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst); + SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src); + stream_executor_->sync_memcpy_dtod(&device_, &device_mem_dst, + &device_mem_src, size, c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + port::Status MemZero(Stream* stream, DeviceMemoryBase* location, + uint64 size) override { + return port::UnimplementedError( + "MemZero is not supported by pluggable device."); + } + port::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern, + uint64 size) override { + return port::UnimplementedError( + "Memset is not supported by pluggable device."); + } + port::Status Memset32(Stream* stream, DeviceMemoryBase* location, + uint32 pattern, uint64 size) override { + return port::UnimplementedError( + "Memset32 is not supported by pluggable device."); + } + bool Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src, + uint64 size) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src); + 
stream_executor_->memcpy_dtoh(&device_, stream_handle, host_dst, + &device_mem_src, size, c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return false; + } + return true; + } + bool Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, const void* host_src, + uint64 size) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst); + stream_executor_->memcpy_htod(&device_, stream_handle, &device_mem_dst, + host_src, size, c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return false; + } + return true; + } + bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, + uint64 size) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst); + SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src); + stream_executor_->memcpy_dtod(&device_, stream_handle, &device_mem_dst, + &device_mem_src, size, c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return false; + } + return true; + } + bool HostCallback(Stream* stream, + std::function callback) override { + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + HostCallbackContext* ctx = new HostCallbackContext{callback}; + return stream_executor_->host_callback(&device_, stream_handle, + &HostCallbackTrampoline, ctx); + } + port::Status AllocateEvent(Event* event) override { + DCHECK(event != nullptr); + return static_cast(event->implementation())->Create(); + } + port::Status DeallocateEvent(Event* event) override { + static_cast(event->implementation())->Destroy(); + return port::Status::OK(); + } + port::Status RecordEvent(Stream* stream, Event* event) override { + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + return static_cast(event->implementation())->Record(stream_handle); + } + port::Status WaitForEvent(Stream* stream, Event* event) override { + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + SP_Event event_handle = + static_cast(event->implementation())->Handle(); + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->wait_for_event(&device_, stream_handle, event_handle, + c_status.get()); + port::Status s = StatusFromTF_Status(c_status.get()); + return s; + } + Event::Status PollForEventStatus(Event* event) override { + SP_Event event_handle = + static_cast(event->implementation())->Handle(); + SE_EventStatus event_status = + stream_executor_->get_event_status(&device_, event_handle); + return SEEventStatusToEventStatus(event_status); + } + bool AllocateStream(Stream* stream) override { + DCHECK(stream != nullptr); + port::Status status = + static_cast(stream->implementation())->Create(); + // TODO(annarev): update AllocateStream to return status instead + // (similar to AllocateEvent). 
+ return status.ok(); + } + void DeallocateStream(Stream* stream) override { + static_cast(stream->implementation())->Destroy(); + } + bool CreateStreamDependency(Stream* dependent, Stream* other) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream dependent_handle = + static_cast(dependent->implementation())->Handle(); + SP_Stream other_handle = + static_cast(other->implementation())->Handle(); + stream_executor_->create_stream_dependency(&device_, dependent_handle, + other_handle, c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return false; + } + return true; + } + bool AllocateTimer(Timer* timer) override { + port::Status status = + static_cast(timer->implementation())->Create(); + // TODO(annarev): change return value of AllocateTimer + // to status (similar to AllocateEvent). + return status.ok(); + } + void DeallocateTimer(Timer* timer) override { + static_cast(timer->implementation())->Destroy(); + } + bool StartTimer(Stream* stream, Timer* timer) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + SP_Timer timer_handle = + static_cast(timer->implementation())->Handle(); + stream_executor_->start_timer(&device_, stream_handle, timer_handle, + c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return false; + } + return true; + } + bool StopTimer(Stream* stream, Timer* timer) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + SP_Timer timer_handle = + static_cast(timer->implementation())->Handle(); + stream_executor_->stop_timer(&device_, stream_handle, timer_handle, + c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return false; + } + return true; + } + port::Status BlockHostForEvent(Stream* stream, Event* event) { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Event event_handle = + static_cast(event->implementation())->Handle(); + stream_executor_->block_host_for_event(&device_, event_handle, + c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + + port::Status BlockHostUntilDone(Stream* stream) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + + // If `block_host_until_done` is set, use it. + if (stream_executor_->block_host_until_done != nullptr) { + stream_executor_->block_host_until_done(&device_, stream_handle, + c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + // Create and record an event and then wait for it. 
+    SP_Event event_handle;
+    stream_executor_->create_event(&device_, &event_handle, c_status.get());
+    TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
+    stream_executor_->record_event(&device_, stream_handle, event_handle,
+                                   c_status.get());
+    port::Status s = StatusFromTF_Status(c_status.get());
+    if (!s.ok()) {
+      stream_executor_->destroy_event(&device_, event_handle);
+      return s;
+    }
+    stream_executor_->block_host_for_event(&device_, event_handle,
+                                           c_status.get());
+    stream_executor_->destroy_event(&device_, event_handle);
+    return StatusFromTF_Status(c_status.get());
+  }
+
+  port::Status GetStatus(Stream* stream) override {
+    OwnedTFStatus c_status(TF_NewStatus());
+    SP_Stream stream_handle =
+        static_cast<CStream*>(stream->implementation())->Handle();
+    stream_executor_->get_stream_status(&device_, stream_handle,
+                                        c_status.get());
+    return StatusFromTF_Status(c_status.get());
+  }
+  int PlatformDeviceCount() override { return visible_device_count_; }
+  port::Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
+    return port::UnimplementedError(
+        "EnablePeerAccessTo is not supported by pluggable device.");
+  }
+  bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
+    return false;
+  }
+
+  bool DeviceMemoryUsage(int64* free, int64* total) const override {
+    static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
+                  "64-bit int types should match in size");
+    return stream_executor_->device_memory_usage(
+        &device_, reinterpret_cast<int64_t*>(free),
+        reinterpret_cast<int64_t*>(total));
+  }
+
+  // Creates a new DeviceDescription object.
+  // Ownership is transferred to the caller.
+  port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
+      const override {
+    // TODO(annarev): Figure out if we need to support more description fields.
+    internal::DeviceDescriptionBuilder builder;
+    builder.set_name(platform_name_);
+    // TODO(annarev): Also set `supports_unified_memory` in DeviceDescription.
+    return builder.Build();
+  }
+
+  // Each call creates a new instance of the platform-specific implementation of
+  // the corresponding interface type.
+  std::unique_ptr<internal::EventInterface> CreateEventImplementation()
+      override {
+    return std::unique_ptr<internal::EventInterface>(
+        new CEvent(&device_, stream_executor_));
+  }
+  std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
+      override {
+    LOG(FATAL)
+        << "CreateKernelImplementation is not supported by pluggable device.";
+  }
+  std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
+      override {
+    return std::unique_ptr<internal::StreamInterface>(
+        new CStream(&device_, stream_executor_));
+  }
+  std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
+    return std::unique_ptr<internal::TimerInterface>(
+        new CTimer(&device_, stream_executor_, timer_fns_));
+  }
+
+ private:
+  SP_Device device_;
+  SP_StreamExecutor* stream_executor_;
+  SP_Platform* platform_;
+  SP_PlatformFns* platform_fns_;
+  SP_TimerFns* timer_fns_;
+  std::string platform_name_;
+  int visible_device_count_;
+};
+}  // namespace
+
+CPlatform::CPlatform(SP_Platform platform,
+                     void (*destroy_platform)(SP_Platform*),
+                     SP_PlatformFns platform_fns,
+                     void (*destroy_platform_fns)(SP_PlatformFns*),
+                     SP_StreamExecutor stream_executor, SP_TimerFns timer_fns)
+    : platform_(std::move(platform)),
+      destroy_platform_(destroy_platform),
+      platform_fns_(std::move(platform_fns)),
+      destroy_platform_fns_(destroy_platform_fns),
+      stream_executor_(std::move(stream_executor)),
+      timer_fns_(std::move(timer_fns)),
+      name_(platform.name) {}
+
+CPlatform::~CPlatform() {
+  executor_cache_.DestroyAllExecutors();
+  platform_fns_.destroy_stream_executor(&platform_, &stream_executor_);
+  platform_fns_.destroy_timer_fns(&platform_, &timer_fns_);
+  destroy_platform_(&platform_);
+  destroy_platform_fns_(&platform_fns_);
+}
+
+port::StatusOr<std::unique_ptr<DeviceDescription>>
+CPlatform::DescriptionForDevice(int ordinal) const {
+  // TODO(annarev): see if we can get StreamExecutor instance
+  // and call GetDeviceDescription. executor_cache_.Get would need
+  // to be made const for it to work.
+  internal::DeviceDescriptionBuilder builder;
+  builder.set_name(name_);
+  return builder.Build();
+}
+port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDevice(int ordinal) {
+  stream_executor::StreamExecutorConfig config;
+  config.ordinal = ordinal;
+  return GetExecutor(config);
+}
+port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDeviceWithPluginConfig(
+    int ordinal, const PluginConfig& plugin_config) {
+  StreamExecutorConfig config;
+  config.ordinal = ordinal;
+  config.plugin_config = plugin_config;
+  return GetExecutor(config);
+}
+port::StatusOr<StreamExecutor*> CPlatform::GetExecutor(
+    const StreamExecutorConfig& config) {
+  return executor_cache_.GetOrCreate(
+      config, [&]() { return GetUncachedExecutor(config); });
+}
+port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
+    const StreamExecutorConfig& config) {
+  // Fill device creation params
+  SE_CreateDeviceParams device_params{SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE};
+  SP_Device device{SP_DEVICE_STRUCT_SIZE};
+  device_params.device = &device;
+  device_params.ext = nullptr;
+  device_params.ordinal = config.ordinal;
+  OwnedTFStatus c_status(TF_NewStatus());
+
+  // Create Device
+  platform_fns_.create_device(&platform_, &device_params, c_status.get());
+  TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
+  TF_RETURN_IF_ERROR(ValidateSPDevice(device));
+
+  auto executor = absl::make_unique<CStreamExecutor>(
+      std::move(device), &stream_executor_, &platform_, &platform_fns_,
+      &timer_fns_, name_, platform_.visible_device_count);
+  auto result = absl::make_unique<StreamExecutor>(this, std::move(executor),
+                                                  config.ordinal);
+  return result;
+}
+
+port::Status InitStreamExecutorPlugin(void* dso_handle) {
+  tensorflow::Env* env = tensorflow::Env::Default();
+
+  // Step 1: Load symbol for `SE_InitPlugin`
+  void* dso_symbol;
+  TF_RETURN_IF_ERROR(
+      env->GetSymbolFromLibrary(dso_handle, "SE_InitPlugin", &dso_symbol));
+
+  // Step 2: Call `SE_InitPlugin`
+  auto init_fn = reinterpret_cast<SEInitPluginFn>(dso_symbol);
+  return InitStreamExecutorPlugin(init_fn);
+}
+
+port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
+  SE_PlatformRegistrationParams params{
+      SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE};
+  SP_Platform platform{SP_PLATFORM_STRUCT_SIZE};
+  SP_PlatformFns platform_fns{SP_PLATFORM_FNS_STRUCT_SIZE};
+  params.major_version = SE_MAJOR;
+  params.minor_version = SE_MINOR;
+  params.patch_version = SE_PATCH;
+  params.platform = &platform;
+  params.platform_fns = &platform_fns;
+
+  OwnedTFStatus c_status(TF_NewStatus());
+  init_fn(&params, c_status.get());
+  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
+  TF_RETURN_IF_ERROR(ValidateSEPlatformRegistrationParams(params));
+  TF_RETURN_IF_ERROR(ValidateSPPlatform(platform));
+  TF_RETURN_IF_ERROR(ValidateSPPlatformFns(platform_fns));
+
+  // Fill stream executor creation params
+  SE_CreateStreamExecutorParams se_params{
+      SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE};
+  SP_StreamExecutor se{SP_STREAMEXECUTOR_STRUCT_SIZE};
+  se_params.stream_executor = &se;
+
+  // Create StreamExecutor
+  platform_fns.create_stream_executor(&platform, &se_params, c_status.get());
+  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
+  TF_RETURN_IF_ERROR(ValidateSPStreamExecutor(se, platform));
+
+  SP_TimerFns timer_fns{SP_TIMER_FNS_STRUCT_SIZE};
+  platform_fns.create_timer_fns(&platform, &timer_fns, c_status.get());
+  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
+  TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns));
+
+  platform_fns.create_timer_fns(&platform, &timer_fns, c_status.get());
+
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get())); + TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns)); + + // Register new platform + std::string platform_name = std::string(platform.name); + std::unique_ptr cplatform( + new stream_executor::CPlatform( + std::move(platform), params.destroy_platform, std::move(platform_fns), + params.destroy_platform_fns, std::move(se), std::move(timer_fns))); + SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform( + std::move(cplatform))); + + // TODO(annarev): Add pluggable device registration here. + return port::Status::OK(); +} +} // namespace stream_executor diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.h b/tensorflow/c/experimental/stream_executor/stream_executor.h new file mode 100644 index 00000000000..796b4e95121 --- /dev/null +++ b/tensorflow/c/experimental/stream_executor/stream_executor.h @@ -0,0 +1,439 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ +#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ +#include +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_status.h" + +// -------------------------------------------------------------------------- +// C API for StreamExecutor. The API is under active development and eventually +// should allow registering a pluggable device with TensorFlow. +// +// Conventions: +// * Struct prefix indicates whether struct fields should be filled by the +// plugin or core implementation: +// * SE_ : set/filled by core unless explicitly marked otherwise. +// * SP_ : set/filled by plugin unless explicitly marked otherwise. +// * We use `struct_size` for version checking. It is exempt from the `SE/SP` +// rule above and should be set both by core and the plugin. +// * For example, `create_device` function receives `SP_Device*` as input +// with `struct_size` populated by core. The plugin is responsible for +// setting `struct_size` as well, along with all other fields. +// * Refer to "TensorFlow Versioning Strategy" section at +// https://github.com/tensorflow/community/pull/257/files. +// * Note that the API is still under active development and doesn't have +// versioning guarantees yet. +// * `void* ext` is a free-form field that can be populated by +// a plugin in `SP_*` structs or potential future extension points in `SE_` +// structs. +// +// Example usage: +// +// /* Sample TensorFlow code below, exact implementation might differ. */ +// // Version checking uses `struct_size`. It is exempt from the `SE/SP` rule +// // above and should be set both by core and the plugin." 
+// SP_Device device { SP_DEVICE_STRUCT_SIZE }; +// SE_CreateDeviceParams params { SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE } ; +// params.device = &device; +// +// /* Plugin code below */ +// constexpr char DEVICE_NAME[] = "MyDevice"; +// constexpr char DEVICE_TYPE[] = "GPU"; +// +// void create_device(const SP_Platform* platform, +// SE_CreateDeviceParams* params, TF_Status* status) { +// // Custom actions based on TensorFlow's view of SP_Device. +// OnTFDeviceView(params->device->struct_size); +// params->device = { SP_DEVICE_STRUCT_SIZE }; +// params->device->device_handle = get_my_device_handle(device->ordinal); +// params->device->ordinal = params->ordinal; +// ... +// } +// +// void destroy_device(const SP_Platform* platform, SP_Device* device) { +// delete_my_device_handle(device->device_handle); +// } +// +// void SE_InitPlugin( +// SE_PlatformRegistrationParams* params, +// TF_Status* status) { +// params->platform = { SP_PLATFORM_STRUCT_SIZE }; +// // Values such as `name` and `type` must outlive SE_InitPlugin call. +// params->platform->name = DEVICE_NAME; +// params->platform->type = DEVICE_TYPE; +// params->platform->visible_device_count = 2; +// params->platform_fns->create_device = create_device; +// params->platform_fns->destroy_device = destroy_device; +// ... +// } + +#define SE_MAJOR 0 +#define SE_MINOR 0 +#define SE_PATCH 1 + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct SP_Stream_st* SP_Stream; +typedef struct SP_Event_st* SP_Event; +typedef struct SP_Timer_st* SP_Timer; +// Takes `callback_arg` passed to `host_callback` as the first argument. +typedef void (*SE_StatusCallbackFn)(void* const, TF_Status* const); + +typedef struct SP_TimerFns { + size_t struct_size; + void* ext; // reserved for future use + uint64_t (*nanoseconds)(SP_Timer timer); +} SP_TimerFns; + +#define SP_TIMER_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_TimerFns, nanoseconds) + +typedef struct SP_AllocatorStats { + size_t struct_size; + int64_t num_allocs; + int64_t bytes_in_use; + int64_t peak_bytes_in_use; + int64_t largest_alloc_size; + + int8_t has_bytes_limit; + int64_t bytes_limit; + + int64_t bytes_reserved; + int64_t peak_bytes_reserved; + + int8_t has_bytes_reservable_limit; + int64_t bytes_reservable_limit; + + int64_t largest_free_block_bytes; +} SP_AllocatorStats; + +#define SP_ALLOCATORSTATS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_AllocatorStats, largest_free_block_bytes) + +// Potential states for an SP_Event. If `poll_for_status` returns anything aside +// from kPending or kComplete, an error has occurred; kUnknown is a bad state. +typedef enum SE_EventStatus { + SE_EVENT_UNKNOWN, + SE_EVENT_ERROR, + SE_EVENT_PENDING, + SE_EVENT_COMPLETE, +} SE_EventStatus; + +// Memory allocation information. +// This matches DeviceMemoryBase defined here: +// https://cs.opensource.google/tensorflow/tensorflow/+/refs/tags/v2.3.0:tensorflow/stream_executor/device_memory.h;l=57 +typedef struct SP_DeviceMemoryBase { + size_t struct_size; + void* ext; // free-form data set by plugin + // Platform-dependent value representing allocated memory. + void* opaque; + uint64_t size; // Size in bytes of this allocation. + uint64_t payload; // Value for plugin's use +} SP_DeviceMemoryBase; + +#define SP_DEVICE_MEMORY_BASE_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_DeviceMemoryBase, payload) + +typedef struct SP_Device { + size_t struct_size; + void* ext; // free-form data set by plugin + int32_t ordinal; // device index + + // Device vendor can store handle to their device representation + // here. 
+ void* device_handle; +} SP_Device; + +#define SP_DEVICE_STRUCT_SIZE TF_OFFSET_OF_END(SP_Device, device_handle) + +typedef struct SE_CreateDeviceParams { + size_t struct_size; + void* ext; // reserved for future use + int32_t ordinal; // device index + + SP_Device* device; // Input/output, struct_size set by TF for plugin to read. + // Subsequently plugin fills the entire struct. +} SE_CreateDeviceParams; + +#define SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SE_CreateDeviceParams, device) + +typedef struct SP_StreamExecutor { + size_t struct_size; + void* ext; // reserved for future use + + /*** ALLOCATION CALLBACKS ***/ + // Synchronously allocates `size` bytes on the underlying platform and returns + // `SP_DeviceMemoryBase` representing that allocation. In the case of failure, + // nullptr is returned. + // `memory_space` is reserved for a potential future usage and should be set + // to 0. + void (*allocate)(const SP_Device* device, uint64_t size, int64_t memory_space, + SP_DeviceMemoryBase* mem); + + // Deallocate the device memory previously allocated via this interface. + // Deallocation of a nullptr-representative value is permitted. + void (*deallocate)(const SP_Device* device, SP_DeviceMemoryBase* memory); + + // Allocates a region of host memory and registers it with the platform API. + // Memory allocated in this manner is required for use in asynchronous memcpy + // operations, such as `memcpy_dtoh`. + void* (*host_memory_allocate)(const SP_Device* device, uint64_t size); + + // Deallocates a region of host memory allocated by `host_memory_allocate`. + void (*host_memory_deallocate)(const SP_Device* device, void* mem); + + // Allocates unified memory space of the given size, if supported. Unified + // memory support should be added by setting `supports_unified_memory` field + // in `SP_Platform`. + void* (*unified_memory_allocate)(const SP_Device* device, uint64_t bytes); + + // Deallocates unified memory space previously allocated with + // `unified_memory_allocate`. Unified + // memory support should be added by setting `supports_unified_memory` field + // in `SP_Platform`. + void (*unified_memory_deallocate)(const SP_Device* device, void* location); + + // Fills SP_AllocatorStats with allocator statistics, if it is available. + // If it is not available, return false. + TF_Bool (*get_allocator_stats)(const SP_Device* device, + SP_AllocatorStats* stats); + // Fills the underlying device memory usage information, if it is + // available. If it is not available (false is returned), free/total need not + // be initialized. + TF_Bool (*device_memory_usage)(const SP_Device* device, int64_t* free, + int64_t* total); + + /*** STREAM CALLBACKS ***/ + // Creates SP_Stream. This call should also allocate stream + // resources on the underlying platform and initializes its + // internals. + void (*create_stream)(const SP_Device* device, SP_Stream* stream, + TF_Status* status); + + // Destroys SP_Stream and deallocates any underlying resources. + void (*destroy_stream)(const SP_Device* device, SP_Stream stream); + + // Causes `dependent` to not begin execution until `other` has finished its + // last-enqueued work. + void (*create_stream_dependency)(const SP_Device* device, SP_Stream dependent, + SP_Stream other, TF_Status* status); + + // Without blocking the device, retrieve the current stream status. + void (*get_stream_status)(const SP_Device* device, SP_Stream stream, + TF_Status* status); + + /*** EVENT CALLBACKS ***/ + // Create SP_Event. 
Performs platform-specific allocation and initialization + // of an event. + void (*create_event)(const SP_Device* device, SP_Event* event, + TF_Status* status); + + // Destroy SE_Event and perform any platform-specific deallocation and + // cleanup of an event. + void (*destroy_event)(const SP_Device* device, SP_Event event); + + // Requests the current status of the event from the underlying platform. + SE_EventStatus (*get_event_status)(const SP_Device* device, SP_Event event); + // Inserts the specified event at the end of the specified stream. + void (*record_event)(const SP_Device* device, SP_Stream stream, + SP_Event event, TF_Status* status); + + // Wait for the specified event at the end of the specified stream. + void (*wait_for_event)(const SP_Device* const device, SP_Stream stream, + SP_Event event, TF_Status* const status); + + /*** TIMER CALLBACKS ***/ + // Creates SP_Timer. Allocates timer resources on the underlying platform + // and initializes its internals, setting `timer` output variable. Sets + // values in `timer_fns` struct. + void (*create_timer)(const SP_Device* device, SP_Timer* timer, + TF_Status* status); + + // Destroy timer and deallocates timer resources on the underlying platform. + void (*destroy_timer)(const SP_Device* device, SP_Timer timer); + + // Records a start event for an interval timer. + void (*start_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer, + TF_Status* status); + + // Records a stop event for an interval timer. + void (*stop_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer, + TF_Status* status); + + /*** MEMCPY CALLBACKS ***/ + // Enqueues a memcpy operation onto stream, with a host destination location + // `host_dst` and a device memory source, with target size `size`. + void (*memcpy_dtoh)(const SP_Device* device, SP_Stream stream, void* host_dst, + const SP_DeviceMemoryBase* device_src, uint64_t size, + TF_Status* status); + + // Enqueues a memcpy operation onto stream, with a device destination + // location and a host memory source, with target size `size`. + void (*memcpy_htod)(const SP_Device* device, SP_Stream stream, + SP_DeviceMemoryBase* device_dst, const void* host_src, + uint64_t size, TF_Status* status); + + // Enqueues a memcpy operation onto stream, with a device destination + // location and a device memory source, with target size `size`. + void (*memcpy_dtod)(const SP_Device* device, SP_Stream stream, + SP_DeviceMemoryBase* device_dst, + const SP_DeviceMemoryBase* device_src, uint64_t size, + TF_Status* status); + + // Blocks the caller while a data segment of the given size is + // copied from the device source to the host destination. + void (*sync_memcpy_dtoh)(const SP_Device* device, void* host_dst, + const SP_DeviceMemoryBase* device_src, uint64_t size, + TF_Status* status); + + // Blocks the caller while a data segment of the given size is + // copied from the host source to the device destination. + void (*sync_memcpy_htod)(const SP_Device* device, + SP_DeviceMemoryBase* device_dst, + const void* host_src, uint64_t size, + TF_Status* status); + + // Blocks the caller while a data segment of the given size is copied from the + // device source to the device destination. + void (*sync_memcpy_dtod)(const SP_Device* device, + SP_DeviceMemoryBase* device_dst, + const SP_DeviceMemoryBase* device_src, uint64_t size, + TF_Status* status); + + // Causes the host code to synchronously wait for the event to complete. 
+ void (*block_host_for_event)(const SP_Device* device, SP_Event event, + TF_Status* status); + + // [Optional] + // Causes the host code to synchronously wait for operations entrained onto + // stream to complete. Effectively a join on the asynchronous device + // operations enqueued on the stream before this program point. + // If not set, then corresponding functionality will be implemented + // by registering an event on the `stream` and waiting for it using + // `block_host_for_event`. + void (*block_host_until_done)(const SP_Device* device, SP_Stream stream, + TF_Status* status); + + // Synchronizes all activity occurring in the StreamExecutor's context (most + // likely a whole device). + void (*synchronize_all_activity)(const SP_Device* device, TF_Status* status); + + // Enqueues on a stream a user-specified function to be run on the host. + // `callback_arg` should be passed as the first argument to `callback_fn`. + TF_Bool (*host_callback)(SP_Device* device, SP_Stream stream, + SE_StatusCallbackFn callback_fn, void* callback_arg); +} SP_StreamExecutor; + +#define SP_STREAMEXECUTOR_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_StreamExecutor, host_callback) + +typedef struct SE_CreateStreamExecutorParams { + size_t struct_size; + void* ext; // reserved for future use + + SP_StreamExecutor* stream_executor; // output, to be filled by plugin +} SE_CreateStreamExecutorParams; + +#define SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SE_CreateStreamExecutorParams, stream_executor) + +typedef struct SP_Platform { + size_t struct_size; + + void* ext; // free-form data set by plugin + + // Platform name. Must be null-terminated. + const char* name; + + // Device type name, for example GPU. Must be null-terminated. + const char* type; + + // Number of visible devices + size_t visible_device_count; + + // Whether this platform supports unified memory. + // Unified memory is a single memory address space accessible from any device. + TF_Bool supports_unified_memory; +} SP_Platform; + +#define SP_PLATFORM_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_Platform, supports_unified_memory) + +typedef struct SP_PlatformFns { + size_t struct_size; + + void* ext; // reserved for future use + + // Callbacks for creating/destroying SP_Device. + void (*create_device)(const SP_Platform* platform, + SE_CreateDeviceParams* params, TF_Status* status); + + // Clean up fields inside SP_Device that were allocated + // by the plugin. `device` itself should not be deleted here. + void (*destroy_device)(const SP_Platform* platform, SP_Device* device); + + // Callbacks for creating/destroying SP_StreamExecutor. + void (*create_stream_executor)(const SP_Platform* platform, + SE_CreateStreamExecutorParams* params, + TF_Status* status); + // Clean up fields inside SP_StreamExecutor that were allocated + // by the plugin. `stream_executor` itself should not be deleted here. + void (*destroy_stream_executor)(const SP_Platform* platform, + SP_StreamExecutor* stream_executor); + + // Callbacks for creating/destroying SP_TimerFns. + void (*create_timer_fns)(const SP_Platform* platform, SP_TimerFns* timer, + TF_Status* status); + + void (*destroy_timer_fns)(const SP_Platform* platform, + SP_TimerFns* timer_fns); +} SP_PlatformFns; + +#define SP_PLATFORM_FNS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_PlatformFns, destroy_timer_fns) + +typedef struct SE_PlatformRegistrationParams { + size_t struct_size; + void* ext; // reserved for future use + + // StreamExecutor C API version. 
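+  // (Set by TensorFlow before SE_InitPlugin is called, so a plugin can verify +  // that it was built against a compatible StreamExecutor C API.)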
+ int32_t major_version; + int32_t minor_version; + int32_t patch_version; + + SP_Platform* platform; // output, set by plugin + SP_PlatformFns* platform_fns; // output, set by plugin + // Clean up fields inside SP_Platform that were allocated + // by the plugin. `platform` itself should not be deleted here. + void (*destroy_platform)(SP_Platform* platform); // out, set by plugin + void (*destroy_platform_fns)( + SP_PlatformFns* platform_fns); // out, set by plugin +} SE_PlatformRegistrationParams; + +#define SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SE_PlatformRegistrationParams, destroy_platform_fns) + +void SE_InitPlugin(SE_PlatformRegistrationParams* params, TF_Status* status); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h new file mode 100644 index 00000000000..079c3661453 --- /dev/null +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Classes and utilities that work with StreamExecutor C API for internal use. +// This includes functions used for device registration and interfaces needed +// for testing. +#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ +#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ + +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" +#include "tensorflow/stream_executor/executor_cache.h" +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/platform.h" + +namespace stream_executor { + +// Plugin initialization function that a device plugin +// must define. +typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const, + TF_Status* const); + +// Registers StreamExecutor platform. +port::Status InitStreamExecutorPlugin(void* dso_handle); + +// Allow registering a StreamExecutor plugin using a function (used for +// testing). 
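+// For tests, a capture-less lambda can serve as the SEInitPluginFn; a minimal +// sketch (mirroring stream_executor_test.cc below): +//   auto plugin_init = [](SE_PlatformRegistrationParams* const params, +//                         TF_Status* const status) -> void { +//     TF_SetStatus(status, TF_OK, ""); +//     // fill params->platform, params->platform_fns and the destroy callbacks +//   }; +//   port::Status s = InitStreamExecutorPlugin(plugin_init);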
+port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn); + +class CPlatform : public Platform { + public: + explicit CPlatform(SP_Platform platform, + void (*destroy_platform)(SP_Platform*), + SP_PlatformFns platform_fns, + void (*destroy_platform_fns)(SP_PlatformFns*), + SP_StreamExecutor stream_executor, SP_TimerFns timer_fns); + ~CPlatform() override; + + Id id() const override { return const_cast(&plugin_id_value_); } + const std::string& Name() const override { return name_; } + int VisibleDeviceCount() const override { + return platform_.visible_device_count; + } + port::StatusOr> DescriptionForDevice( + int ordinal) const override; + port::StatusOr ExecutorForDevice(int ordinal) override; + port::StatusOr ExecutorForDeviceWithPluginConfig( + int ordinal, const PluginConfig& plugin_config) override; + port::StatusOr GetExecutor( + const StreamExecutorConfig& config) override; + port::StatusOr> GetUncachedExecutor( + const StreamExecutorConfig& config) override; + + // Trace listener is not supported + void RegisterTraceListener(std::unique_ptr listener) override { + LOG(FATAL) << "RegisterTraceListener is not supported by pluggable device"; + } + void UnregisterTraceListener(TraceListener* listener) override {} + + void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); } + + private: + SP_Platform platform_; + void (*destroy_platform_)(SP_Platform*); + SP_PlatformFns platform_fns_; + void (*destroy_platform_fns_)(SP_PlatformFns*); + SP_StreamExecutor stream_executor_; + SP_TimerFns timer_fns_; + const std::string name_; + int plugin_id_value_; + stream_executor::ExecutorCache executor_cache_; +}; + +} // namespace stream_executor +#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc new file mode 100644 index 00000000000..c280a3975b7 --- /dev/null +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc @@ -0,0 +1,883 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0(the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" + +#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/stream_executor/event.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/stream_executor/stream.h" +#include "tensorflow/stream_executor/stream_executor_pimpl.h" +#include "tensorflow/stream_executor/timer.h" + +struct SP_Stream_st { + explicit SP_Stream_st(int id) : stream_id(id) {} + int stream_id; +}; + +struct SP_Event_st { + explicit SP_Event_st(int id) : event_id(id) {} + int event_id; +}; + +struct SP_Timer_st { + explicit SP_Timer_st(int id) : timer_id(id) {} + int timer_id; +}; + +namespace stream_executor { +namespace { +constexpr int DEVICE_COUNT = 2; +constexpr char DEVICE_NAME[] = "MyDevice"; +constexpr char DEVICE_TYPE[] = "GPU"; + +/*** Create SP_StreamExecutor (with empty functions) ***/ +void allocate(const SP_Device* const device, uint64_t size, + int64_t memory_space, SP_DeviceMemoryBase* const mem) {} +void deallocate(const SP_Device* const device, SP_DeviceMemoryBase* const mem) { +} +void* host_memory_allocate(const SP_Device* const device, uint64_t size) { + return nullptr; +} +void host_memory_deallocate(const SP_Device* const device, void* mem) {} +TF_Bool get_allocator_stats(const SP_Device* const device, + SP_AllocatorStats* const stats) { + return true; +} +TF_Bool device_memory_usage(const SP_Device* const device, int64_t* const free, + int64_t* const total) { + return true; +} +void create_stream(const SP_Device* const device, SP_Stream* stream, + TF_Status* const status) { + stream = nullptr; +} +void destroy_stream(const SP_Device* const device, SP_Stream stream) {} +void create_stream_dependency(const SP_Device* const device, + SP_Stream dependent, SP_Stream other, + TF_Status* const status) {} +void get_stream_status(const SP_Device* const device, SP_Stream stream, + TF_Status* const status) {} +void create_event(const SP_Device* const device, SP_Event* event, + TF_Status* const status) { + event = nullptr; +} +void destroy_event(const SP_Device* const device, SP_Event event) {} +SE_EventStatus get_event_status(const SP_Device* const device, SP_Event event) { + return SE_EVENT_UNKNOWN; +} +void record_event(const SP_Device* const device, SP_Stream stream, + SP_Event event, TF_Status* const status) {} +void wait_for_event(const SP_Device* const device, SP_Stream stream, + SP_Event event, TF_Status* const status) {} +void create_timer(const SP_Device* const device, SP_Timer* timer, + TF_Status* const status) {} +void destroy_timer(const SP_Device* const device, SP_Timer timer) {} +void start_timer(const SP_Device* const device, SP_Stream stream, + SP_Timer timer, TF_Status* const status) {} +void stop_timer(const SP_Device* const device, SP_Stream stream, SP_Timer timer, + TF_Status* const status) {} +void memcpy_dtoh(const SP_Device* const device, SP_Stream stream, + void* host_dst, const SP_DeviceMemoryBase* const device_src, + uint64_t size, TF_Status* const status) {} +void memcpy_htod(const SP_Device* const device, SP_Stream stream, + SP_DeviceMemoryBase* const device_dst, const void* host_src, + uint64_t size, TF_Status* const status) {} +void sync_memcpy_dtoh(const SP_Device* const device, void* host_dst, + 
const SP_DeviceMemoryBase* const device_src, + uint64_t size, TF_Status* const status) {} +void sync_memcpy_htod(const SP_Device* const device, + SP_DeviceMemoryBase* const device_dst, + const void* host_src, uint64_t size, + TF_Status* const status) {} +void block_host_for_event(const SP_Device* const device, SP_Event event, + TF_Status* const status) {} +void synchronize_all_activity(const SP_Device* const device, + TF_Status* const status) {} +TF_Bool host_callback(SP_Device* const device, SP_Stream stream, + SE_StatusCallbackFn const callback_fn, + void* const callback_arg) { + return true; +} + +void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) { + *se = {SP_STREAMEXECUTOR_STRUCT_SIZE}; + se->allocate = allocate; + se->deallocate = deallocate; + se->host_memory_allocate = host_memory_allocate; + se->host_memory_deallocate = host_memory_deallocate; + se->get_allocator_stats = get_allocator_stats; + se->device_memory_usage = device_memory_usage; + se->create_stream = create_stream; + se->destroy_stream = destroy_stream; + se->create_stream_dependency = create_stream_dependency; + se->get_stream_status = get_stream_status; + se->create_event = create_event; + se->destroy_event = destroy_event; + se->get_event_status = get_event_status; + se->record_event = record_event; + se->wait_for_event = wait_for_event; + se->create_timer = create_timer; + se->destroy_timer = destroy_timer; + se->start_timer = start_timer; + se->stop_timer = stop_timer; + se->memcpy_dtoh = memcpy_dtoh; + se->memcpy_htod = memcpy_htod; + se->sync_memcpy_dtoh = sync_memcpy_dtoh; + se->sync_memcpy_htod = sync_memcpy_htod; + se->block_host_for_event = block_host_for_event; + se->synchronize_all_activity = synchronize_all_activity; + se->host_callback = host_callback; +} + +/*** Create SP_TimerFns ***/ +uint64_t nanoseconds(SP_Timer timer) { return timer->timer_id; } + +void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) { + timer_fns->nanoseconds = nanoseconds; +} + +/*** Create SP_Platform ***/ +void create_timer_fns(const SP_Platform* platform, SP_TimerFns* timer_fns, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultTimerFns(timer_fns); +} +void destroy_timer_fns(const SP_Platform* platform, SP_TimerFns* timer_fns) {} + +void create_stream_executor(const SP_Platform* platform, + SE_CreateStreamExecutorParams* params, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultStreamExecutor(params->stream_executor); +} +void destroy_stream_executor(const SP_Platform* platform, + SP_StreamExecutor* se) {} + +void create_device(const SP_Platform* platform, SE_CreateDeviceParams* params, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + params->device->struct_size = SP_DEVICE_STRUCT_SIZE; +} +void destroy_device(const SP_Platform* platform, SP_Device* device) {} + +void PopulateDefaultPlatform(SP_Platform* platform, + SP_PlatformFns* platform_fns) { + *platform = {SP_PLATFORM_STRUCT_SIZE}; + platform->name = DEVICE_NAME; + platform->type = DEVICE_TYPE; + platform->visible_device_count = DEVICE_COUNT; + platform_fns->create_device = create_device; + platform_fns->destroy_device = destroy_device; + platform_fns->create_stream_executor = create_stream_executor; + platform_fns->destroy_stream_executor = destroy_stream_executor; + platform_fns->create_timer_fns = create_timer_fns; + platform_fns->destroy_timer_fns = destroy_timer_fns; +} + +void destroy_platform(SP_Platform* const platform) {} +void destroy_platform_fns(SP_PlatformFns* const platform_fns) {} + +/*** 
Registration tests ***/ +TEST(StreamExecutor, SuccessfulRegistration) { + auto plugin_init = [](SE_PlatformRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultPlatform(params->platform, params->platform_fns); + params->destroy_platform = destroy_platform; + params->destroy_platform_fns = destroy_platform_fns; + }; + port::Status status = InitStreamExecutorPlugin(plugin_init); + TF_ASSERT_OK(status); + port::StatusOr maybe_platform = + MultiPlatformManager::PlatformWithName("MyDevice"); + TF_ASSERT_OK(maybe_platform.status()); + Platform* platform = maybe_platform.ConsumeValueOrDie(); + ASSERT_EQ(platform->Name(), DEVICE_NAME); + ASSERT_EQ(platform->VisibleDeviceCount(), DEVICE_COUNT); + + port::StatusOr maybe_executor = + platform->ExecutorForDevice(0); + TF_ASSERT_OK(maybe_executor.status()); + StreamExecutor* executor = maybe_executor.ConsumeValueOrDie(); + ASSERT_EQ(executor->GetDeviceDescription().name(), "MyDevice"); +} + +TEST(StreamExecutor, NameNotSet) { + auto plugin_init = [](SE_PlatformRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultPlatform(params->platform, params->platform_fns); + params->platform->name = nullptr; + params->destroy_platform = destroy_platform; + params->destroy_platform_fns = destroy_platform_fns; + }; + + port::Status status = InitStreamExecutorPlugin(plugin_init); + ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); + ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set."); +} + +TEST(StreamExecutor, CreateDeviceNotSet) { + auto plugin_init = [](SE_PlatformRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultPlatform(params->platform, params->platform_fns); + params->platform_fns->create_device = nullptr; + params->destroy_platform = destroy_platform; + params->destroy_platform_fns = destroy_platform_fns; + }; + + port::Status status = InitStreamExecutorPlugin(plugin_init); + ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); + ASSERT_EQ(status.error_message(), + "'create_device' field in SP_PlatformFns must be set."); +} + +TEST(StreamExecutor, UnifiedMemoryAllocateNotSet) { + auto plugin_init = [](SE_PlatformRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultPlatform(params->platform, params->platform_fns); + params->platform->supports_unified_memory = true; + params->destroy_platform = destroy_platform; + params->destroy_platform_fns = destroy_platform_fns; + }; + + port::Status status = InitStreamExecutorPlugin(plugin_init); + ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); + ASSERT_EQ( + status.error_message(), + "'unified_memory_allocate' field in SP_StreamExecutor must be set."); +} + +/*** StreamExecutor behavior tests ***/ +class StreamExecutorTest : public ::testing::Test { + protected: + StreamExecutorTest() {} + void SetUp() override { + PopulateDefaultPlatform(&platform_, &platform_fns_); + PopulateDefaultStreamExecutor(&se_); + PopulateDefaultTimerFns(&timer_fns_); + } + void TearDown() override {} + + StreamExecutor* GetExecutor(int ordinal) { + if (!cplatform_) { + cplatform_ = absl::make_unique( + platform_, destroy_platform, platform_fns_, destroy_platform_fns, se_, + timer_fns_); + } + port::StatusOr maybe_executor = + cplatform_->ExecutorForDevice(ordinal); + 
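+    // Fail the test immediately if executor creation failed, so the individual +    // test cases below can assume a valid StreamExecutor.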
TF_CHECK_OK(maybe_executor.status()); + return maybe_executor.ConsumeValueOrDie(); + } + SP_Platform platform_; + SP_PlatformFns platform_fns_; + SP_StreamExecutor se_; + SP_TimerFns timer_fns_; + std::unique_ptr cplatform_; +}; + +TEST_F(StreamExecutorTest, Allocate) { + se_.allocate = [](const SP_Device* const device, uint64_t size, + int64_t memory_space, SP_DeviceMemoryBase* const mem) { + mem->struct_size = SP_DEVICE_MEMORY_BASE_STRUCT_SIZE; + mem->opaque = malloc(size); + mem->size = size; + }; + se_.deallocate = [](const SP_Device* const device, + SP_DeviceMemoryBase* const mem) { + EXPECT_EQ(mem->size, 2 * sizeof(int)); + free(mem->opaque); + mem->opaque = nullptr; + mem->size = 0; + }; + StreamExecutor* executor = GetExecutor(0); + DeviceMemory mem = executor->AllocateArray(2); + ASSERT_NE(mem.opaque(), nullptr); + ASSERT_EQ(mem.size(), 2 * sizeof(int)); + executor->Deallocate(&mem); + ASSERT_EQ(mem.opaque(), nullptr); +} + +TEST_F(StreamExecutorTest, HostMemoryAllocate) { + static bool allocate_called = false; + static bool deallocate_called = false; + se_.host_memory_allocate = [](const SP_Device* const device, uint64_t size) { + allocate_called = true; + return malloc(size); + }; + se_.host_memory_deallocate = [](const SP_Device* const device, void* mem) { + free(mem); + deallocate_called = true; + }; + StreamExecutor* executor = GetExecutor(0); + ASSERT_FALSE(allocate_called); + void* mem = executor->HostMemoryAllocate(8); + ASSERT_NE(mem, nullptr); + ASSERT_TRUE(allocate_called); + ASSERT_FALSE(deallocate_called); + executor->HostMemoryDeallocate(mem); + ASSERT_TRUE(deallocate_called); +} + +TEST_F(StreamExecutorTest, UnifiedMemoryAllocate) { + static bool allocate_called = false; + static bool deallocate_called = false; + se_.unified_memory_allocate = [](const SP_Device* const device, + uint64_t size) { + allocate_called = true; + return malloc(size); + }; + se_.unified_memory_deallocate = [](const SP_Device* const device, void* mem) { + free(mem); + deallocate_called = true; + }; + StreamExecutor* executor = GetExecutor(0); + ASSERT_FALSE(allocate_called); + void* mem = executor->UnifiedMemoryAllocate(8); + ASSERT_NE(mem, nullptr); + ASSERT_TRUE(allocate_called); + ASSERT_FALSE(deallocate_called); + executor->UnifiedMemoryDeallocate(mem); + ASSERT_TRUE(deallocate_called); +} + +TEST_F(StreamExecutorTest, GetAllocatorStats) { + se_.get_allocator_stats = [](const SP_Device* const device, + SP_AllocatorStats* const stat) -> TF_Bool { + stat->struct_size = SP_ALLOCATORSTATS_STRUCT_SIZE; + stat->bytes_in_use = 123; + return true; + }; + + StreamExecutor* executor = GetExecutor(0); + absl::optional optional_stats = executor->GetAllocatorStats(); + ASSERT_TRUE(optional_stats.has_value()); + AllocatorStats stats = optional_stats.value(); + ASSERT_EQ(stats.bytes_in_use, 123); +} + +TEST_F(StreamExecutorTest, DeviceMemoryUsage) { + se_.device_memory_usage = [](const SP_Device* const device, + int64_t* const free, + int64_t* const total) -> TF_Bool { + *free = 45; + *total = 7; + return true; + }; + + StreamExecutor* executor = GetExecutor(0); + int64 free = 0; + int64 total = 0; + executor->DeviceMemoryUsage(&free, &total); + ASSERT_EQ(free, 45); + ASSERT_EQ(total, 7); +} + +TEST_F(StreamExecutorTest, CreateStream) { + static bool stream_created = false; + static bool stream_deleted = false; + se_.create_stream = [](const SP_Device* const device, SP_Stream* stream, + TF_Status* const status) -> void { + *stream = new SP_Stream_st(14); + stream_created = true; + }; + 
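+  // The destroy callback checks that the SP_Stream created above (id 14) is +  // the same object handed back for cleanup before deleting it.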
se_.destroy_stream = [](const SP_Device* const device, + SP_Stream stream) -> void { + auto custom_stream = static_cast(stream); + ASSERT_EQ(custom_stream->stream_id, 14); + delete custom_stream; + stream_deleted = true; + }; + + StreamExecutor* executor = GetExecutor(0); + ASSERT_FALSE(stream_created); + Stream* stream = new Stream(executor); + stream->Init(); + ASSERT_TRUE(stream->ok()); + ASSERT_TRUE(stream_created); + ASSERT_FALSE(stream_deleted); + delete stream; + ASSERT_TRUE(stream_deleted); +} + +TEST_F(StreamExecutorTest, CreateStreamDependency) { + static bool create_stream_dependency_called = false; + se_.create_stream_dependency = [](const SP_Device* const device, + SP_Stream dependent, SP_Stream other, + TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + create_stream_dependency_called = true; + }; + + StreamExecutor* executor = GetExecutor(0); + Stream dependent(executor); + dependent.Init(); + Stream other(executor); + other.Init(); + ASSERT_FALSE(create_stream_dependency_called); + dependent.ThenWaitFor(&other); + ASSERT_TRUE(create_stream_dependency_called); +} + +TEST_F(StreamExecutorTest, StreamStatus) { + static bool status_ok = true; + se_.get_stream_status = [](const SP_Device* const device, SP_Stream stream, + TF_Status* const status) -> void { + if (status_ok) { + TF_SetStatus(status, TF_OK, ""); + } else { + TF_SetStatus(status, TF_INTERNAL, "Test error"); + } + }; + + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.RefreshStatus()); + status_ok = false; + auto updated_status = stream.RefreshStatus(); + ASSERT_FALSE(stream.ok()); + ASSERT_EQ(updated_status.error_message(), "Test error"); +} + +TEST_F(StreamExecutorTest, CreateEvent) { + static bool event_created = false; + static bool event_deleted = false; + se_.create_event = [](const SP_Device* const device, SP_Event* event, + TF_Status* const status) -> void { + *event = new SP_Event_st(123); + event_created = true; + }; + se_.destroy_event = [](const SP_Device* const device, + SP_Event event) -> void { + auto custom_event = static_cast(event); + ASSERT_EQ(custom_event->event_id, 123); + delete custom_event; + event_deleted = true; + }; + + StreamExecutor* executor = GetExecutor(0); + ASSERT_FALSE(event_created); + Event* event = new Event(executor); + event->Init(); + ASSERT_TRUE(event_created); + ASSERT_FALSE(event_deleted); + delete event; + ASSERT_TRUE(event_deleted); +} + +TEST_F(StreamExecutorTest, PollForEventStatus) { + static SE_EventStatus event_status = SE_EVENT_COMPLETE; + se_.create_event = [](const SP_Device* const device, SP_Event* event, + TF_Status* const status) -> void { + *event = new SP_Event_st(123); + }; + se_.destroy_event = [](const SP_Device* const device, + SP_Event event) -> void { delete event; }; + se_.get_event_status = [](const SP_Device* const device, + SP_Event event) -> SE_EventStatus { + EXPECT_EQ(event->event_id, 123); + return event_status; + }; + + StreamExecutor* executor = GetExecutor(0); + Event event(executor); + event.Init(); + ASSERT_EQ(event.PollForStatus(), Event::Status::kComplete); + event_status = SE_EVENT_ERROR; + ASSERT_EQ(event.PollForStatus(), Event::Status::kError); +} + +TEST_F(StreamExecutorTest, RecordAndWaitForEvent) { + static bool record_called = false; + static bool wait_called = false; + se_.create_stream = [](const SP_Device* const device, SP_Stream* stream, + TF_Status* const status) -> void { + *stream = new SP_Stream_st(1); + }; + 
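+  // The callbacks below check that the stream (id 1) and the event (id 2) +  // created by these factory functions are the ones later passed to +  // record_event and wait_for_event.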
se_.destroy_stream = [](const SP_Device* const device, + SP_Stream stream) -> void { delete stream; }; + se_.create_event = [](const SP_Device* const device, SP_Event* event, + TF_Status* const status) -> void { + *event = new SP_Event_st(2); + }; + se_.destroy_event = [](const SP_Device* const device, + SP_Event event) -> void { delete event; }; + se_.record_event = [](const SP_Device* const device, SP_Stream stream, + SP_Event event, TF_Status* const status) { + EXPECT_EQ(stream->stream_id, 1); + EXPECT_EQ(event->event_id, 2); + TF_SetStatus(status, TF_OK, ""); + record_called = true; + }; + se_.wait_for_event = [](const SP_Device* const device, SP_Stream stream, + SP_Event event, TF_Status* const status) { + EXPECT_EQ(stream->stream_id, 1); + EXPECT_EQ(event->event_id, 2); + TF_SetStatus(status, TF_OK, ""); + wait_called = true; + }; + + StreamExecutor* executor = GetExecutor(0); + Event event(executor); + event.Init(); + Stream stream(executor); + stream.Init(); + ASSERT_FALSE(record_called); + stream.ThenRecordEvent(&event); + ASSERT_TRUE(record_called); + ASSERT_FALSE(wait_called); + stream.ThenWaitFor(&event); + ASSERT_TRUE(wait_called); +} + +TEST_F(StreamExecutorTest, CreateTimer) { + static bool timer_created = false; + static bool timer_deleted = false; + se_.create_timer = [](const SP_Device* const device, SP_Timer* timer, + TF_Status* const status) -> void { + *timer = new SP_Timer_st(25); + timer_created = true; + }; + se_.destroy_timer = [](const SP_Device* const device, + SP_Timer timer) -> void { + auto custom_timer = static_cast(timer); + EXPECT_EQ(custom_timer->timer_id, 25); + delete custom_timer; + timer_deleted = true; + }; + + StreamExecutor* executor = GetExecutor(0); + ASSERT_FALSE(timer_created); + Stream stream(executor); + stream.Init(); + Timer* timer = new Timer(executor); + stream.InitTimer(timer); + ASSERT_TRUE(stream.ok()); + ASSERT_TRUE(timer_created); + ASSERT_FALSE(timer_deleted); + delete timer; + ASSERT_TRUE(timer_deleted); +} + +TEST_F(StreamExecutorTest, StartTimer) { + static bool start_called = false; + static bool stop_called = false; + static TF_Code start_timer_status = TF_OK; + static TF_Code stop_timer_status = TF_OK; + se_.create_timer = [](const SP_Device* const device, SP_Timer* timer, + TF_Status* const status) -> void { + *timer = new SP_Timer_st(7); + }; + se_.destroy_timer = [](const SP_Device* const device, + SP_Timer timer) -> void { delete timer; }; + se_.start_timer = [](const SP_Device* const device, SP_Stream stream, + SP_Timer timer, TF_Status* const status) { + TF_SetStatus(status, start_timer_status, ""); + EXPECT_EQ(timer->timer_id, 7); + start_called = true; + }; + se_.stop_timer = [](const SP_Device* const device, SP_Stream stream, + SP_Timer timer, TF_Status* const status) { + TF_SetStatus(status, stop_timer_status, ""); + EXPECT_EQ(timer->timer_id, 7); + stop_called = true; + }; + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + Timer timer(executor); + stream.InitTimer(&timer); + + // Check both start and stop succeed + ASSERT_FALSE(start_called); + stream.ThenStartTimer(&timer); + ASSERT_TRUE(start_called); + ASSERT_FALSE(stop_called); + stream.ThenStopTimer(&timer); + ASSERT_TRUE(stop_called); + + // Check start timer fails + ASSERT_TRUE(stream.ok()); + start_timer_status = TF_UNKNOWN; + stream.ThenStartTimer(&timer); + ASSERT_FALSE(stream.ok()); + + // Check stop timer fails + start_timer_status = TF_OK; + stop_timer_status = TF_UNKNOWN; + Stream stream2(executor); + 
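+  // A second stream is used because the first one is already in an error +  // state after the failing ThenStartTimer call above.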
stream2.Init(); + Timer timer2(executor); + stream2.InitTimer(&timer2); + stream2.ThenStartTimer(&timer2); + ASSERT_TRUE(stream2.ok()); + stream2.ThenStopTimer(&timer2); + ASSERT_FALSE(stream2.ok()); +} + +TEST_F(StreamExecutorTest, TimerFns) { + se_.create_timer = [](const SP_Device* const device, SP_Timer* timer, + TF_Status* const status) -> void { + *timer = new SP_Timer_st(25000); + }; + se_.destroy_timer = [](const SP_Device* const device, + SP_Timer timer) -> void { delete timer; }; + + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + Timer timer(executor); + stream.InitTimer(&timer); + // Our test nanoseconds function just returns value + // passed to SP_Timer_st constructor. + ASSERT_EQ(timer.Nanoseconds(), 25000); + ASSERT_EQ(timer.Microseconds(), 25); +} + +TEST_F(StreamExecutorTest, MemcpyToHost) { + se_.create_stream = [](const SP_Device* const device, SP_Stream* stream, + TF_Status* const status) -> void { + *stream = new SP_Stream_st(14); + }; + se_.destroy_stream = [](const SP_Device* const device, + SP_Stream stream) -> void { delete stream; }; + + se_.memcpy_dtoh = [](const SP_Device* const device, SP_Stream stream, + void* host_dst, + const SP_DeviceMemoryBase* const device_src, + uint64_t size, TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + EXPECT_EQ(stream->stream_id, 14); + std::memcpy(host_dst, device_src->opaque, size); + }; + + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + size_t size = sizeof(int); + int src_data = 34; + int dst_data = 2; + DeviceMemoryBase device_src(&src_data, size); + Stream& stream_ref = stream.ThenMemcpy(&dst_data, device_src, size); + ASSERT_EQ(dst_data, 34); + ASSERT_EQ(stream_ref.implementation(), stream.implementation()); +} + +TEST_F(StreamExecutorTest, MemcpyFromHost) { + se_.memcpy_htod = [](const SP_Device* const device, SP_Stream stream, + SP_DeviceMemoryBase* const device_dst, + const void* host_src, uint64_t size, + TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + std::memcpy(device_dst->opaque, host_src, size); + }; + + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + size_t size = sizeof(int); + int src_data = 18; + int dst_data = 0; + DeviceMemoryBase device_dst(&dst_data, size); + stream.ThenMemcpy(&device_dst, &src_data, size); + ASSERT_EQ(dst_data, 18); +} + +TEST_F(StreamExecutorTest, MemcpyDeviceToDevice) { + se_.memcpy_dtod = [](const SP_Device* const device, SP_Stream stream, + SP_DeviceMemoryBase* const device_dst, + const SP_DeviceMemoryBase* const device_src, + uint64_t size, TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + std::memcpy(device_dst->opaque, device_src->opaque, size); + }; + + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + size_t size = sizeof(int); + int src_data = 18; + int dst_data = 0; + DeviceMemoryBase device_dst(&dst_data, size); + DeviceMemoryBase device_src(&src_data, size); + stream.ThenMemcpy(&device_dst, device_src, size); + ASSERT_EQ(dst_data, 18); +} + +TEST_F(StreamExecutorTest, SyncMemcpyToHost) { + se_.sync_memcpy_dtoh = [](const SP_Device* const device, void* host_dst, + const SP_DeviceMemoryBase* const device_src, + uint64_t size, TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + std::memcpy(host_dst, device_src->opaque, size); + }; + + StreamExecutor* executor = GetExecutor(0); + size_t size = sizeof(int); + int src_data = 34; + int dst_data = 2; + 
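+  // Plain host memory stands in for device memory here; the sync_memcpy_dtoh +  // stub above simply performs a std::memcpy.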
DeviceMemoryBase device_src(&src_data, size); + TF_ASSERT_OK(executor->SynchronousMemcpyD2H(device_src, size, &dst_data)); + ASSERT_EQ(dst_data, 34); +} + +TEST_F(StreamExecutorTest, SyncMemcpyFromHost) { + se_.sync_memcpy_htod = + [](const SP_Device* const device, SP_DeviceMemoryBase* const device_dst, + const void* host_src, uint64_t size, TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + std::memcpy(device_dst->opaque, host_src, size); + }; + + StreamExecutor* executor = GetExecutor(0); + size_t size = sizeof(int); + int src_data = 18; + int dst_data = 0; + DeviceMemoryBase device_dst(&dst_data, size); + TF_ASSERT_OK(executor->SynchronousMemcpyH2D(&src_data, size, &device_dst)); + ASSERT_EQ(dst_data, 18); +} + +TEST_F(StreamExecutorTest, SyncMemcpyDeviceToDevice) { + se_.sync_memcpy_dtod = [](const SP_Device* const device, + SP_DeviceMemoryBase* const device_dst, + const SP_DeviceMemoryBase* const device_src, + uint64_t size, TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + std::memcpy(device_dst->opaque, device_src->opaque, size); + }; + + StreamExecutor* executor = GetExecutor(0); + size_t size = sizeof(int); + int src_data = 18; + int dst_data = 0; + DeviceMemoryBase device_dst(&dst_data, size); + DeviceMemoryBase device_src(&src_data, size); + ASSERT_TRUE(executor->SynchronousMemcpy(&device_dst, device_src, size)); + ASSERT_EQ(dst_data, 18); +} + +TEST_F(StreamExecutorTest, BlockHostForEvent) { + static bool block_host_for_event_called = false; + se_.create_event = [](const SP_Device* const device, SP_Event* event, + TF_Status* const status) { + *event = new SP_Event_st(357); + }; + se_.destroy_event = [](const SP_Device* const device, SP_Event event) { + delete event; + }; + se_.block_host_for_event = [](const SP_Device* const device, SP_Event event, + TF_Status* const status) -> void { + ASSERT_EQ(event->event_id, 357); + TF_SetStatus(status, TF_OK, ""); + block_host_for_event_called = true; + }; + + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + ASSERT_FALSE(block_host_for_event_called); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + ASSERT_TRUE(block_host_for_event_called); +} + +TEST_F(StreamExecutorTest, BlockHostUntilDone) { + static bool block_host_until_done_called = false; + se_.create_stream = [](const SP_Device* const device, SP_Stream* stream, + TF_Status* const status) { + *stream = new SP_Stream_st(58); + }; + se_.destroy_stream = [](const SP_Device* const device, SP_Stream stream) { + delete stream; + }; + se_.block_host_until_done = [](const SP_Device* const device, + SP_Stream stream, + TF_Status* const status) -> void { + ASSERT_EQ(stream->stream_id, 58); + TF_SetStatus(status, TF_OK, ""); + block_host_until_done_called = true; + }; + + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + ASSERT_FALSE(block_host_until_done_called); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + ASSERT_TRUE(block_host_until_done_called); +} + +TEST_F(StreamExecutorTest, SynchronizeAllActivity) { + static bool synchronize_all_called = false; + se_.synchronize_all_activity = [](const SP_Device* const device, + TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + synchronize_all_called = true; + }; + + StreamExecutor* executor = GetExecutor(0); + ASSERT_FALSE(synchronize_all_called); + ASSERT_TRUE(executor->SynchronizeAllActivity()); + ASSERT_TRUE(synchronize_all_called); +} + +TEST_F(StreamExecutorTest, HostCallbackOk) { + se_.host_callback = [](SP_Device* const 
device, SP_Stream stream, + SE_StatusCallbackFn const callback_fn, + void* const callback_arg) -> TF_Bool { + TF_Status* status = TF_NewStatus(); + callback_fn(callback_arg, status); + bool ok = TF_GetCode(status) == TF_OK; + TF_DeleteStatus(status); + return ok; + }; + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + std::function callback = []() -> port::Status { + return port::Status::OK(); + }; + stream.ThenDoHostCallbackWithStatus(callback); + ASSERT_TRUE(stream.ok()); +} + +TEST_F(StreamExecutorTest, HostCallbackError) { + se_.host_callback = [](SP_Device* const device, SP_Stream stream, + SE_StatusCallbackFn const callback_fn, + void* const callback_arg) -> TF_Bool { + TF_Status* status = TF_NewStatus(); + callback_fn(callback_arg, status); + bool ok = TF_GetCode(status) == TF_OK; + TF_DeleteStatus(status); + return ok; + }; + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + std::function callback = []() -> port::Status { + return port::UnimplementedError("Unimplemented"); + }; + stream.ThenDoHostCallbackWithStatus(callback); + ASSERT_FALSE(stream.ok()); +} +} // namespace +} // namespace stream_executor diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 20a6c5117cf..ed501b5b101 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -261,7 +261,6 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index, size_t len, TF_Status* status) { TF_SetStatus(status, TF_OK, ""); auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context); - static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), "64-bit int types should match in size"); tensorflow::gtl::ArraySlice dimarray( @@ -279,4 +278,73 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index, return nullptr; } return tf_tensor; -} \ No newline at end of file +} + +TF_Tensor* TF_ForwardInputOrAllocateOutput( + TF_OpKernelContext* context, int* candidate_input_indices, + int num_candidate_input_indices, int output_index, int64_t* output_dims, + int output_num_dims, int* forwarded_input, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context); + + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + tensorflow::gtl::ArraySlice input_indices_array( + candidate_input_indices, num_candidate_input_indices); + tensorflow::gtl::ArraySlice output_dimarray( + reinterpret_cast(output_dims), output_num_dims); + tensorflow::Tensor* output_tensor_pointer; + tensorflow::Status s = cc_ctx->forward_input_or_allocate_output( + input_indices_array, output_index, + tensorflow::TensorShape(output_dimarray), &output_tensor_pointer, + forwarded_input); + if (!s.ok()) { + ::tensorflow::Set_TF_Status_from_Status(status, s); + return nullptr; + } + TF_Tensor* tf_tensor_output = TF_TensorFromTensor(*output_tensor_pointer, &s); + if (!s.ok()) { + ::tensorflow::Set_TF_Status_from_Status(status, s); + return nullptr; + } + return tf_tensor_output; +} + +TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype, + int64_t* dims, int num_dims, + TF_AllocatorAttributes* attributes, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context); + TF_SetStatus(status, TF_OK, ""); + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + tensorflow::gtl::ArraySlice dimarray( + 
reinterpret_cast<tensorflow::int64*>(dims), num_dims); + if (attributes && !attributes->struct_size) { + TF_SetStatus( + status, TF_INVALID_ARGUMENT, + "TF_AllocatorAttributes struct " + "size member must be set to TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE"); + return nullptr; + } + tensorflow::AllocatorAttributes allocator_attr; + if (attributes && attributes->on_host) { + allocator_attr.set_on_host(true); + } + tensorflow::Status s; + tensorflow::Tensor tensor; + s = cc_ctx->allocate_temp(static_cast<tensorflow::DataType>(dtype), + tensorflow::TensorShape(dimarray), &tensor, + allocator_attr); + if (!s.ok()) { + ::tensorflow::Set_TF_Status_from_Status(status, s); + return nullptr; + } + TF_Tensor* tf_tensor; + tf_tensor = TF_TensorFromTensor(tensor, &s); + if (!s.ok()) { + ::tensorflow::Set_TF_Status_from_Status(status, s); + return nullptr; + } + return tf_tensor; +} diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index c7138a39c73..489aa5399a5 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_tensor.h" // Macro to control visibility of exported symbols in the shared library (.so, // .dylib, .dll). @@ -199,6 +200,26 @@ TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int64_t* dims, int num_dims, size_t len, TF_Status* status); +// Tries to forward one of the inputs given in candidate_input_indices to +// output[output_index]. If none of the given inputs can be forwarded, calls +// allocate_output() to allocate a new output buffer. The index of the +// forwarded input will be assigned to output argument forwarded_input (if it's +// not nullptr). If no inputs are forwarded, forwarded_input will be assigned +// -1. +TF_CAPI_EXPORT TF_Tensor* TF_ForwardInputOrAllocateOutput( + TF_OpKernelContext* context, int* candidate_input_indices, + int num_candidate_input_indices, int output_index, int64_t* output_dims, + int output_num_dims, int* forwarded_input, TF_Status* status); + +// Allocates a temporary Tensor of the specified type and shape. The +// Tensor must not be used after kernel construction is +// complete.
+// +// num_dims must equal the size of array dims +TF_CAPI_EXPORT extern TF_Tensor* TF_AllocateTemp( + TF_OpKernelContext* context, TF_DataType dtype, int64_t* dims, int num_dims, + TF_AllocatorAttributes* alloc_attrs, TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/kernels/BUILD b/tensorflow/c/kernels/BUILD index 5fec068bd73..6bb2b347a30 100644 --- a/tensorflow/c/kernels/BUILD +++ b/tensorflow/c/kernels/BUILD @@ -39,6 +39,33 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "histogram_summary_op", + prefix = "histogram_summary_op", + deps = [ + "//tensorflow/c:kernels", + "//tensorflow/c:tf_status", + "//tensorflow/c:tf_tensor", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "merge_summary_op", + prefix = "merge_summary_op", + deps = [ + "//tensorflow/c:kernels", + "//tensorflow/c:tf_status", + "//tensorflow/c:tf_tensor", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + tf_gen_op_libs( op_lib_names = ["bitcast"], deps = [ @@ -59,6 +86,24 @@ tf_gen_op_libs( ], ) +tf_gen_op_libs( + op_lib_names = ["histogram_summary"], + deps = [ + "//tensorflow/c:ops", + "//tensorflow/c:tf_status", + "//tensorflow/core:lib", + ], +) + +tf_gen_op_libs( + op_lib_names = ["merge_summary"], + deps = [ + "//tensorflow/c:ops", + "//tensorflow/c:tf_status", + "//tensorflow/core:lib", + ], +) + tf_cc_test( name = "bitcast_op_test", srcs = ["bitcast_op_test.cc"], @@ -87,6 +132,23 @@ tf_cc_test( ], ) +tf_cc_test( + name = "summary_op_benchmark_test", + size = "small", + srcs = ["summary_op_benchmark_test.cc"], + deps = [ + ":summary_op", + "//tensorflow/c:kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "tensor_shape_utils", srcs = ["tensor_shape_utils.cc"], @@ -122,6 +184,8 @@ filegroup( name = "android_all_op_kernels", srcs = [ "bitcast_op.cc", + "histogram_summary_op.cc", + "merge_summary_op.cc", "summary_op.cc", "tensor_shape_utils.cc", "tensor_shape_utils.h", @@ -133,6 +197,8 @@ filegroup( name = "android_all_ops", srcs = [ "ops/bitcast.cc", + "ops/histogram_summary.cc", + "ops/merge_summary.cc", "ops/summary.cc", ], ) diff --git a/tensorflow/c/kernels/histogram_summary_op.cc b/tensorflow/c/kernels/histogram_summary_op.cc new file mode 100644 index 00000000000..143a2675a05 --- /dev/null +++ b/tensorflow/c/kernels/histogram_summary_op.cc @@ -0,0 +1,165 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/c/kernels.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/framework/selective_registration.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/histogram/histogram.h" +#include "tensorflow/core/platform/bfloat16.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/platform/types.h" + +namespace { + +// Operators used to create a std::unique_ptr for TF_Tensor and TF_Status. +struct TFTensorDeleter { + void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); } +}; + +struct TFStatusDeleter { + void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); } +}; + +// Struct that wraps TF_Tensor and TF_Status to delete once out of scope. +using Safe_TF_TensorPtr = std::unique_ptr; +using Safe_TF_StatusPtr = std::unique_ptr; + +// Used to pass the operation node name from kernel construction to +// kernel computation. +struct HistogramSummaryOp { + std::string op_node_name; +}; + +void* HistogramSummaryOp_Create(TF_OpKernelConstruction* ctx) { + HistogramSummaryOp* kernel = new HistogramSummaryOp; + TF_StringView string_view_name = TF_OpKernelConstruction_GetName(ctx); + kernel->op_node_name = + std::string(string_view_name.data, string_view_name.len); + return kernel; +} + +void HistogramSummaryOp_Delete(void* kernel) { + delete static_cast(kernel); +} + +template +void HistogramSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) { + HistogramSummaryOp* k = static_cast(kernel); + TF_Tensor* tags; + TF_Tensor* values; + Safe_TF_StatusPtr status(TF_NewStatus()); + TF_GetInput(ctx, 0, &tags, status.get()); + Safe_TF_TensorPtr safe_tags_ptr(tags); + if (TF_GetCode(status.get()) != TF_OK) { + TF_OpKernelContext_Failure(ctx, status.get()); + return; + } + TF_GetInput(ctx, 1, &values, status.get()); + Safe_TF_TensorPtr safe_values_ptr(values); + if (TF_GetCode(status.get()) != TF_OK) { + TF_OpKernelContext_Failure(ctx, status.get()); + return; + } + if (TF_NumDims(safe_tags_ptr.get()) != 0) { + TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, "tags must be scalar"); + TF_OpKernelContext_Failure(ctx, status.get()); + return; + } + // Cast values to array to access tensor elements by index + auto values_array = static_cast(TF_TensorData(safe_values_ptr.get())); + tensorflow::histogram::Histogram histo; + for (int64_t i = 0; i < TF_TensorElementCount(safe_values_ptr.get()); ++i) { + const double double_val = static_cast(values_array[i]); + if (Eigen::numext::isnan(double_val)) { + std::ostringstream err; + err << "Nan in summary histogram for: " << k->op_node_name; + TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str()); + TF_OpKernelContext_Failure(ctx, status.get()); + return; + } else if (Eigen::numext::isinf(double_val)) { + std::ostringstream err; + err << "Infinity in Histogram for: " << k->op_node_name; + TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str()); + TF_OpKernelContext_Failure(ctx, status.get()); + return; + } + histo.Add(double_val); + } + tensorflow::Summary s; + tensorflow::Summary::Value* v = s.add_value(); + const tensorflow::tstring& tag = + 
*(static_cast<tensorflow::tstring*>(TF_TensorData(safe_tags_ptr.get()))); + v->set_tag(tag.data(), tag.size()); + histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */); + + Safe_TF_TensorPtr summary_tensor(TF_AllocateOutput( + /*context=*/ctx, /*index=*/0, /*dtype=*/TF_ExpectedOutputDataType(ctx, 0), + /*dims=*/nullptr, /*num_dims=*/0, + /*len=*/sizeof(tensorflow::tstring), status.get())); + + if (TF_GetCode(status.get()) != TF_OK) { + TF_OpKernelContext_Failure(ctx, status.get()); + return; + } + tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>( + TF_TensorData(summary_tensor.get())); + CHECK(SerializeToTString(s, output_tstring)); +} + +template <typename T> +void RegisterHistogramSummaryOpKernel() { + TF_Status* status = TF_NewStatus(); + { + auto* builder = TF_NewKernelBuilder( + "HistogramSummary", tensorflow::DEVICE_CPU, &HistogramSummaryOp_Create, + &HistogramSummaryOp_Compute<T>, &HistogramSummaryOp_Delete); + TF_KernelBuilder_TypeConstraint( + builder, "T", + static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint"; + TF_RegisterKernelBuilder("HistogramSummary", builder, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) + << "Error while registering Histogram Summary kernel"; + } + TF_DeleteStatus(status); +} + +// A dummy static variable initialized by a lambda whose side-effect is to +// register the Histogram Summary kernel. +TF_ATTRIBUTE_UNUSED static bool IsHistogramSummaryOpKernelRegistered = []() { + if (SHOULD_REGISTER_OP_KERNEL("HistogramSummary")) { + RegisterHistogramSummaryOpKernel<tensorflow::int64>(); + RegisterHistogramSummaryOpKernel<tensorflow::uint64>(); + RegisterHistogramSummaryOpKernel<tensorflow::int32>(); + RegisterHistogramSummaryOpKernel<tensorflow::uint32>(); + RegisterHistogramSummaryOpKernel<tensorflow::uint16>(); + RegisterHistogramSummaryOpKernel<tensorflow::int16>(); + RegisterHistogramSummaryOpKernel<tensorflow::int8>(); + RegisterHistogramSummaryOpKernel<tensorflow::uint8>(); + RegisterHistogramSummaryOpKernel<Eigen::half>(); + RegisterHistogramSummaryOpKernel<tensorflow::bfloat16>(); + RegisterHistogramSummaryOpKernel<float>(); + RegisterHistogramSummaryOpKernel<double>(); + } + return true; +}(); +} // namespace diff --git a/tensorflow/c/kernels/merge_summary_op.cc b/tensorflow/c/kernels/merge_summary_op.cc new file mode 100644 index 00000000000..e45029319e5 --- /dev/null +++ b/tensorflow/c/kernels/merge_summary_op.cc @@ -0,0 +1,123 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ +#include +#include +#include + +#include "tensorflow/c/kernels.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/framework/selective_registration.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/tstring.h" + +namespace { + +// Operators used to create a std::unique_ptr for TF_Tensor and TF_Status +struct TFTensorDeleter { + void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); } +}; + +struct TFStatusDeleter { + void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); } +}; + +// Struct that wraps TF_Tensor and TF_Status to delete once out of scope +using Safe_TF_TensorPtr = std::unique_ptr; +using Safe_TF_StatusPtr = std::unique_ptr; + +// dummy functions used for kernel registration +void* MergeSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; } + +void MergeSummaryOp_Delete(void* kernel) {} + +void MergeSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) { + tensorflow::Summary s; + std::unordered_set tags; + Safe_TF_StatusPtr status(TF_NewStatus()); + for (int input_num = 0; input_num < TF_NumInputs(ctx); ++input_num) { + TF_Tensor* input; + TF_GetInput(ctx, input_num, &input, status.get()); + Safe_TF_TensorPtr safe_input_ptr(input); + if (TF_GetCode(status.get()) != TF_OK) { + TF_OpKernelContext_Failure(ctx, status.get()); + return; + } + auto tags_array = + static_cast(TF_TensorData(safe_input_ptr.get())); + for (int i = 0; i < TF_TensorElementCount(safe_input_ptr.get()); ++i) { + const tensorflow::tstring& s_in = tags_array[i]; + tensorflow::Summary summary_in; + if (!tensorflow::ParseProtoUnlimited(&summary_in, s_in)) { + TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, + "Could not parse one of the summary inputs"); + TF_OpKernelContext_Failure(ctx, status.get()); + return; + } + for (int v = 0; v < summary_in.value_size(); ++v) { + // This tag is unused by the TensorSummary op, so no need to check for + // duplicates. 
const tensorflow::string& tag = summary_in.value(v).tag(); + if ((!tag.empty()) && !tags.insert(tag).second) { + std::ostringstream err; + err << "Duplicate tag " << tag << " found in summary inputs "; + TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str()); + TF_OpKernelContext_Failure(ctx, status.get()); + return; + } + *s.add_value() = summary_in.value(v); + } + } + } + Safe_TF_TensorPtr summary_tensor(TF_AllocateOutput( + /*context=*/ctx, /*index=*/0, /*dtype=*/TF_ExpectedOutputDataType(ctx, 0), + /*dims=*/nullptr, /*num_dims=*/0, + /*len=*/sizeof(tensorflow::tstring), status.get())); + if (TF_GetCode(status.get()) != TF_OK) { + TF_OpKernelContext_Failure(ctx, status.get()); + return; + } + tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>( + TF_TensorData(summary_tensor.get())); + CHECK(SerializeToTString(s, output_tstring)); +} + +void RegisterMergeSummaryOpKernel() { + TF_Status* status = TF_NewStatus(); + { + auto* builder = TF_NewKernelBuilder( + "MergeSummary", tensorflow::DEVICE_CPU, &MergeSummaryOp_Create, + &MergeSummaryOp_Compute, &MergeSummaryOp_Delete); + TF_RegisterKernelBuilder("MergeSummary", builder, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) + << "Error while registering Merge Summary kernel"; + } + TF_DeleteStatus(status); +} + +// A dummy static variable initialized by a lambda whose side-effect is to +// register the Merge Summary kernel. +TF_ATTRIBUTE_UNUSED static bool IsMergeSummaryOpKernelRegistered = []() { + if (SHOULD_REGISTER_OP_KERNEL("MergeSummary")) { + RegisterMergeSummaryOpKernel(); + } + return true; +}(); + +} // namespace diff --git a/tensorflow/c/kernels/ops/histogram_summary.cc b/tensorflow/c/kernels/ops/histogram_summary.cc new file mode 100644 index 00000000000..67d4d1b0a5b --- /dev/null +++ b/tensorflow/c/kernels/ops/histogram_summary.cc @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/c/ops.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/framework/selective_registration.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" + +static void histogram_summary_shape_inference_fn(TF_ShapeInferenceContext* ctx, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + TF_ShapeHandle* result = TF_ShapeInferenceContextScalar(ctx); + TF_ShapeInferenceContextSetOutput(ctx, 0, result, status); + TF_DeleteShapeHandle(result); +} + +void Register_HistogramSummaryOp() { + TF_Status* status = TF_NewStatus(); + + TF_OpDefinitionBuilder* op_builder = + TF_NewOpDefinitionBuilder("HistogramSummary"); + TF_OpDefinitionBuilderAddInput(op_builder, "tag: string"); + TF_OpDefinitionBuilderAddInput(op_builder, "values: T"); + TF_OpDefinitionBuilderAddOutput(op_builder, "summary: string"); + TF_OpDefinitionBuilderAddAttr(op_builder, "T: realnumbertype = DT_FLOAT"); + TF_OpDefinitionBuilderSetShapeInferenceFunction( + op_builder, &histogram_summary_shape_inference_fn); + + TF_RegisterOpDefinition(op_builder, status); + CHECK_EQ(TF_GetCode(status), TF_OK) + << "HistogramSummary op registration failed: " << TF_Message(status); + TF_DeleteStatus(status); +} + +TF_ATTRIBUTE_UNUSED static bool HistogramSummaryOpRegistered = []() { + if (SHOULD_REGISTER_OP("HistogramSummary")) { + Register_HistogramSummaryOp(); + } + return true; +}(); diff --git a/tensorflow/c/kernels/ops/merge_summary.cc b/tensorflow/c/kernels/ops/merge_summary.cc new file mode 100644 index 00000000000..991c469fff6 --- /dev/null +++ b/tensorflow/c/kernels/ops/merge_summary.cc @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/ops.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/framework/selective_registration.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" + +static void merge_summary_shape_inference_fn(TF_ShapeInferenceContext* ctx, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + TF_ShapeHandle* result = TF_ShapeInferenceContextScalar(ctx); + TF_ShapeInferenceContextSetOutput(ctx, 0, result, status); + TF_DeleteShapeHandle(result); +} + +void Register_MergeSummaryOp() { + TF_Status* status = TF_NewStatus(); + + TF_OpDefinitionBuilder* op_builder = + TF_NewOpDefinitionBuilder("MergeSummary"); + TF_OpDefinitionBuilderAddInput(op_builder, "inputs: N * string"); + TF_OpDefinitionBuilderAddOutput(op_builder, "summary: string"); + TF_OpDefinitionBuilderAddAttr(op_builder, "N: int >= 1"); + TF_OpDefinitionBuilderSetShapeInferenceFunction( + op_builder, &merge_summary_shape_inference_fn); + + TF_RegisterOpDefinition(op_builder, status); + CHECK_EQ(TF_GetCode(status), TF_OK) + << "MergeSummary op registration failed: " << TF_Message(status); + TF_DeleteStatus(status); +} + +TF_ATTRIBUTE_UNUSED static bool MergeSummaryOpRegistered = []() { + if (SHOULD_REGISTER_OP("MergeSummary")) { + Register_MergeSummaryOp(); + } + return true; +}(); diff --git a/tensorflow/c/kernels/summary_op.cc b/tensorflow/c/kernels/summary_op.cc index bd528da4165..ac7eced0ae7 100644 --- a/tensorflow/c/kernels/summary_op.cc +++ b/tensorflow/c/kernels/summary_op.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/selective_registration.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" diff --git a/tensorflow/c/kernels/summary_op_benchmark_test.cc b/tensorflow/c/kernels/summary_op_benchmark_test.cc new file mode 100644 index 00000000000..887a86066d3 --- /dev/null +++ b/tensorflow/c/kernels/summary_op_benchmark_test.cc @@ -0,0 +1,71 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+Graph* BM_ScalarSummaryOp(TensorShape shape, std::string tag, float value) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor tags(DT_STRING, shape);
+  Tensor values(DT_FLOAT, shape);
+  for (int i = 0; i < tags.NumElements(); ++i) {
+    tags.flat<tstring>()(i) = tag;
+    values.flat<float>()(i) = value;
+  }
+  Node* ret;
+  TF_CHECK_OK(NodeBuilder(g->NewName("dummy"), "ScalarSummary")
+                  .Input(test::graph::Constant(g, tags))
+                  .Input(test::graph::Constant(g, values))
+                  .Attr("T", DT_FLOAT)
+                  .Finalize(g, &ret));
+  return g;
+}
+
+// Macro used to parse an initializer list into a TensorShape.
+#define DIMARGS(...) \
+  { __VA_ARGS__ }
+// Random parameters for testing
+constexpr char longTagParam[] = "LONGTAG____________________________";
+constexpr float largeValueParam = 2352352.2623433;
+
+#define BM_ScalarSummaryDev(device, dims, name, tag, value) \
+  void BM_ScalarSummary##name##device(int iters) {          \
+    testing::StopTiming();                                  \
+    TensorShape tensorshape(DIMARGS dims);                  \
+    auto g = BM_ScalarSummaryOp(tensorshape, #tag, value);  \
+    testing::StartTiming();                                 \
+    test::Benchmark("cpu", g).Run(iters);                   \
+  }                                                         \
+  BENCHMARK(BM_ScalarSummary##name##device);
+
+BM_ScalarSummaryDev(Cpu, (5, 10, 100), Base, Tag, 5.2);
+// Benchmark for large shapes
+BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeShape, Tag, 5.2);
+// Benchmark for large tag tstring
+BM_ScalarSummaryDev(Cpu, (5, 10, 100), LongTag, longTagParam, 5.2);
+// Benchmark for large values
+BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeValue, Tag, largeValueParam);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc
index 3c8ac934428..c9df2cc34d1 100644
--- a/tensorflow/c/kernels_test.cc
+++ b/tensorflow/c/kernels_test.cc
@@ -368,6 +368,16 @@ class DeviceKernelOpTest : public OpsTestBase {
 #endif
 };
+
+// Validates that the tensor has shape and type corresponding to
+// dims and dtype.
+void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
+                     TF_DataType dtype);
+
+// Copies data of length tensor_size_bytes from values to tensor.
+template +void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes, + TF_OpKernelContext* ctx); + REGISTER_OP("AllocateOutputOp1").Output("output1: float"); TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) { @@ -379,22 +389,11 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) { TF_Tensor* output = TF_AllocateOutput( /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim, /*num_dims=*/1, /*len=*/tensor_size_bytes, s); - EXPECT_EQ(TF_OK, TF_GetCode(s)); - EXPECT_EQ(TF_FLOAT, TF_TensorType(output)); - EXPECT_EQ(1, TF_NumDims(output)); - EXPECT_EQ(1, TF_Dim(output, 0)); + validate_tensor(output, &dim, 1, TF_FLOAT); // Set output to 3 - float* data = reinterpret_cast(TF_TensorData(output)); - float value = 3.0f; -#if GOOGLE_CUDA - OpKernelContext* cc_ctx = reinterpret_cast(ctx); - cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, &value, - tensor_size_bytes); -#else - *data = value; -#endif - + float values[1] = {3.0f}; + set_tensor_data(output, values, tensor_size_bytes, ctx); TF_DeleteStatus(s); TF_DeleteTensor(output); }; @@ -417,12 +416,8 @@ TEST_F(DeviceKernelOpTest, TestAllocateEmptyOutput) { TF_Tensor* output = TF_AllocateOutput( /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim, /*num_dims=*/1, /*len=*/0, s); - EXPECT_EQ(TF_OK, TF_GetCode(s)); - EXPECT_EQ(TF_FLOAT, TF_TensorType(output)); - EXPECT_EQ(1, TF_NumDims(output)); - EXPECT_EQ(0, TF_Dim(output, 0)); - + validate_tensor(output, &dim, 1, TF_FLOAT); TF_DeleteStatus(s); TF_DeleteTensor(output); }; @@ -442,27 +437,16 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) { TF_Status* s = TF_NewStatus(); // Allocate 2x3 output int64_t dim[2] = {2, 3}; - size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT); + size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT) * 6; TF_Tensor* output = TF_AllocateOutput( /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/dim, /*num_dims=*/2, /*len=*/tensor_size_bytes, s); EXPECT_EQ(TF_OK, TF_GetCode(s)); - EXPECT_EQ(TF_FLOAT, TF_TensorType(output)); - EXPECT_EQ(2, TF_NumDims(output)); - EXPECT_EQ(2, TF_Dim(output, 0)); - EXPECT_EQ(3, TF_Dim(output, 1)); + validate_tensor(output, dim, 2, TF_FLOAT); // Set output to [1 2 3 4 5 6] - void* data = TF_TensorData(output); - float value[6] = {1, 2, 3, 4, 5, 6}; -#if GOOGLE_CUDA - OpKernelContext* cc_ctx = reinterpret_cast(ctx); - cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, value, - tensor_size_bytes); -#else - memcpy(data, value, tensor_size_bytes); -#endif - + float values[6] = {1, 2, 3, 4, 5, 6}; + set_tensor_data(output, values, tensor_size_bytes, ctx); TF_DeleteStatus(s); TF_DeleteTensor(output); }; @@ -474,4 +458,200 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) { EXPECT_EQ("Tensor", output->DebugString(100)); } + +REGISTER_OP("AllocateTempOp1").Output("output1: float"); + +TEST_F(DeviceKernelOpTest, TestAllocateTempSizeOne) { + auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + // Allocate scalar TF_Tensor + TF_Status* s = TF_NewStatus(); + int64_t dim = 1; + TF_AllocatorAttributes alloc_attrs; + alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE; +#if GOOGLE_CUDA + alloc_attrs.on_host = 0; +#else + alloc_attrs.on_host = 1; +#endif + TF_Tensor* output = TF_AllocateTemp( + /*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim, + /*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s); + size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + validate_tensor(output, &dim, 1, TF_FLOAT); + + // Set 
TF_Tensor value to 3 + float values[1] = {3.0f}; + set_tensor_data(output, values, tensor_size_bytes, ctx); + TF_SetOutput(ctx, 0, output, s); + TF_DeleteStatus(s); + TF_DeleteTensor(output); + }; + + SetupOp("AllocateTempOp1", "AllocateTemp1", my_compute_func); + + TF_ASSERT_OK(RunOpKernel()); + Tensor* output = GetOutput(0); + EXPECT_EQ("Tensor", + output->DebugString(100)); +} + +REGISTER_OP("AllocateTempOp0").Output("output1: float"); + +TEST_F(DeviceKernelOpTest, TestAllocateTempEmpty) { + auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + TF_Status* s = TF_NewStatus(); + // Allocate empty TF_Tensor + int64_t dim = 0; + TF_AllocatorAttributes alloc_attrs; + alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE; +#if GOOGLE_CUDA + alloc_attrs.on_host = 0; +#else + alloc_attrs.on_host = 1; +#endif + TF_Tensor* output = TF_AllocateTemp( + /*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim, + /*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + validate_tensor(output, &dim, 1, TF_FLOAT); + TF_SetOutput(ctx, 0, output, s); + TF_DeleteStatus(s); + TF_DeleteTensor(output); + }; + + SetupOp("AllocateTempOp0", "AllocateTemp0", my_compute_func); + + TF_ASSERT_OK(RunOpKernel()); + Tensor* output = GetOutput(0); + EXPECT_EQ("Tensor", + output->DebugString(100)); +} + +REGISTER_OP("AllocateTempOp2x3").Output("output1: float"); + +TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) { + auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + TF_Status* s = TF_NewStatus(); + size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT); + // Allocate 2x3 TF_Tensor + int64_t dim[2] = {2, 3}; + TF_AllocatorAttributes alloc_attrs; + alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE; +#if GOOGLE_CUDA + alloc_attrs.on_host = 0; +#else + alloc_attrs.on_host = 1; +#endif + TF_Tensor* output = TF_AllocateTemp( + /*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/dim, + /*num_dims=*/2, /*allocator_attributes*/ &alloc_attrs, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + validate_tensor(output, dim, 2, TF_FLOAT); + + // Set TF_Tensor values to [1 2 3 4 5 6] + float values[6] = {1, 2, 3, 4, 5, 6}; + set_tensor_data(output, values, tensor_size_bytes, ctx); + TF_SetOutput(ctx, 0, output, s); + TF_DeleteStatus(s); + TF_DeleteTensor(output); + }; + + SetupOp("AllocateTempOp2x3", "AllocateTempOp2x3", my_compute_func); + + TF_ASSERT_OK(RunOpKernel()); + Tensor* output = GetOutput(0); + EXPECT_EQ("Tensor", + output->DebugString(100)); +} + +TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) { + const char* node_name = "TestForwardInputOrAllocateOutputKernel"; + const char* op_name = "BazOp"; + const char* device_name = "FakeDeviceName"; + + REGISTER_OP(op_name) + .Input("input1: float") + .Input("input2: float") + .Output("output1: float") + .Attr("SomeDataTypeAttr: type"); + + // A kernel whose Compute function that forwards a scalar input to output + auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + TF_Status* s = TF_NewStatus(); + int candidate_input_indices[1] = {0}; + int forwarded_input; + int64_t output_dims[1] = {}; + TF_Tensor* output = TF_ForwardInputOrAllocateOutput( + /*context=*/ctx, candidate_input_indices, + /*num_candidate_input_indices=*/1, + /*output_index=*/0, output_dims, /*output_num_dims=*/0, + &forwarded_input, /*status=*/s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + EXPECT_EQ(forwarded_input, 0); + EXPECT_EQ(TF_FLOAT, TF_TensorType(output)); + EXPECT_EQ(0, TF_NumDims(output)); + 
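// Editorial note (not part of the patch): TF_ForwardInputOrAllocateOutput
// either reuses one of the candidate inputs' buffers as output 0 (reporting
// which one via forwarded_input, expected to be input 0 here) or falls back
// to allocating a fresh output, mirroring
// OpKernelContext::forward_input_or_allocate_output in the C++ API.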
TF_DeleteStatus(s);
+    TF_DeleteTensor(output);
+  };
+
+  TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
+                                                  my_compute_func, nullptr);
+
+  {
+    TF_Status* status = TF_NewStatus();
+    TF_RegisterKernelBuilder(node_name, builder, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status));
+    TF_DeleteStatus(status);
+  }
+
+  {
+    OpKernelContext::Params p;
+    DummyDevice dummy_device(nullptr);
+    p.device = &dummy_device;
+    AllocatorAttributes alloc_attrs;
+    p.output_attr_array = &alloc_attrs;
+
+    Tensor t(123.0f);
+
+    gtl::InlinedVector<TensorValue, 4> inputs;
+    // GetFakeKernel requires a NodeDef with two inputs
+    inputs.emplace_back(&t);
+    inputs.emplace_back();
+    p.inputs = &inputs;
+
+    Status status;
+    std::unique_ptr<OpKernel> kernel =
+        GetFakeKernel(device_name, op_name, node_name, &status);
+    TF_EXPECT_OK(status);
+    ASSERT_NE(nullptr, kernel.get());
+
+    p.op_kernel = kernel.get();
+    OpKernelContext ctx(&p);
+    kernel->Compute(&ctx);
+    ASSERT_EQ(123, ctx.mutable_output(0)->scalar<float>()());
+  }
+}
+
+void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
+                     TF_DataType dtype) {
+  EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
+  EXPECT_EQ(num_dims, TF_NumDims(tensor));
+  for (int i = 0; i < num_dims; ++i) {
+    EXPECT_EQ(dims[i], TF_Dim(tensor, i));
+  }
+}
+
+template <typename T>
+void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
+                     TF_OpKernelContext* ctx) {
+  T* data = reinterpret_cast<T*>(TF_TensorData(tensor));
+#if GOOGLE_CUDA
+  OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
+  cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, values,
+                                                tensor_size_bytes);
+#else
+  memcpy(data, values, tensor_size_bytes);
+#endif
+}
 }  // namespace tensorflow
diff --git a/tensorflow/c/logging.cc b/tensorflow/c/logging.cc
index bf6bf069fff..13c9e6ac208 100644
--- a/tensorflow/c/logging.cc
+++ b/tensorflow/c/logging.cc
@@ -28,6 +28,7 @@ void TF_Log(TF_LogLevel level, const char* fmt, ...) {
   va_list args;
   va_start(args, fmt);
   auto message = BuildMessage(fmt, args);
+  va_end(args);
   switch (level) {
     case TF_INFO:
       LOG(INFO) << message;
@@ -48,6 +49,7 @@ void TF_VLog(int level, const char* fmt, ...) {
   va_list args;
   va_start(args, fmt);
   auto message = BuildMessage(fmt, args);
+  va_end(args);
   VLOG(level) << message;
 }
@@ -55,5 +57,6 @@ void TF_DVLog(int level, const char* fmt, ...) {
   va_list args;
   va_start(args, fmt);
   auto message = BuildMessage(fmt, args);
+  va_end(args);
   DVLOG(level) << message;
 }
diff --git a/tensorflow/c/tf_status_helper.h b/tensorflow/c/tf_status_helper.h
index ff8085f1229..a895e608159 100644
--- a/tensorflow/c/tf_status_helper.h
+++ b/tensorflow/c/tf_status_helper.h
@@ -28,6 +28,14 @@ void Set_TF_Status_from_Status(TF_Status* tf_status,
 // Returns a "status" from "tf_status".
tensorflow::Status StatusFromTF_Status(const TF_Status* tf_status); +namespace internal { +struct TF_StatusDeleter { + void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); } +}; +} // namespace internal + +using TF_StatusPtr = std::unique_ptr; + } // namespace tensorflow #endif // TENSORFLOW_C_TF_STATUS_HELPER_H_ diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc index 0feb986ce44..39d2683226f 100644 --- a/tensorflow/c/tf_tensor.cc +++ b/tensorflow/c/tf_tensor.cc @@ -288,7 +288,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) { if (!tensor.CopyFrom(src, src.shape())) { return nullptr; } - return new TF_Tensor{new tensorflow::TensorInterface(tensor)}; + return new TF_Tensor{new tensorflow::TensorInterface(std::move(tensor))}; } Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { diff --git a/tensorflow/c/tf_tensor.h b/tensorflow/c/tf_tensor.h index acdf053e63a..e0a026f984f 100644 --- a/tensorflow/c/tf_tensor.h +++ b/tensorflow/c/tf_tensor.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" @@ -45,6 +46,16 @@ limitations under the License. extern "C" { #endif +// Allocator Attributes used for tensor allocation. +typedef struct TF_AllocatorAttributes { + size_t struct_size; + // Set boolean to 1 for CPU allocation, else 0. + TF_Bool on_host; +} TF_AllocatorAttributes; + +#define TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE \ + TF_OFFSET_OF_END(TF_AllocatorAttributes, on_host) + // -------------------------------------------------------------------------- // TF_Tensor holds a multi-dimensional array of elements of a single data type. // For all types other than TF_STRING, the data buffer stores elements diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index e1fad8e697a..8602bfafff8 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -558,6 +558,7 @@ tf_gen_op_wrappers_cc( "io_ops", "linalg_ops", "list_ops", + "map_ops", "logging_ops", "lookup_ops", "manip_ops", diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index a67d349bab7..fddbcfec6e6 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -47,6 +47,7 @@ cc_library( # TODO(b/111634734): :lib and :protos_all contain dependencies that # cannot be built on mobile platforms. Instead, include the appropriate # tf_lib depending on the build platform. 
+ "@com_google_absl//absl/memory:memory", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ]), @@ -56,7 +57,7 @@ tf_cc_test( name = "reader_test", srcs = ["reader_test.cc"], data = [ - ":saved_model_half_plus_two", + ":saved_model_test_files", ], linkstatic = 1, deps = [ @@ -148,7 +149,7 @@ tf_cc_test( name = "bundle_v2_test", srcs = ["bundle_v2_test.cc"], data = [ - ":saved_model_half_plus_two", + ":saved_model_test_files", ], linkstatic = 1, deps = [ @@ -165,12 +166,13 @@ tf_cc_test( name = "saved_model_bundle_test", srcs = ["saved_model_bundle_test.cc"], data = [ - ":saved_model_half_plus_two", + ":saved_model_test_files", ], linkstatic = 1, deps = [ ":constants", ":loader", + ":reader", ":signature_constants", ":tag_constants", "//tensorflow/core:lib", @@ -186,7 +188,7 @@ tf_cc_test( name = "saved_model_bundle_lite_test", srcs = ["saved_model_bundle_lite_test.cc"], data = [ - ":saved_model_half_plus_two", + ":saved_model_test_files", ], linkstatic = 1, deps = [ @@ -225,7 +227,7 @@ py_binary( # TODO(b/32673259): add a test to continuously validate these files. filegroup( - name = "saved_model_half_plus_two", + name = "saved_model_test_files", srcs = glob([ "testdata/half_plus_two_pbtxt/**", "testdata/half_plus_two_main_op/**", @@ -234,9 +236,15 @@ filegroup( "testdata/x_plus_y_v2_debuginfo/**", "testdata/CyclicModule/**", "testdata/VarsAndArithmeticObjectGraph/**", + "testdata/fuzz_generated/**", ]), ) +alias( + name = "saved_model_half_plus_two", + actual = ":saved_model_test_files", +) + exports_files( glob([ "testdata/half_plus_two_pbtxt/**", @@ -246,5 +254,6 @@ exports_files( "testdata/x_plus_y_v2_debuginfo/**", "testdata/CyclicModule/**", "testdata/VarsAndArithmeticObjectGraph/**", + "testdata/fuzz_generated/**", ]), ) diff --git a/tensorflow/cc/saved_model/experimental/public/BUILD b/tensorflow/cc/saved_model/experimental/public/BUILD index 3e9a671a61f..9640848ebf5 100644 --- a/tensorflow/cc/saved_model/experimental/public/BUILD +++ b/tensorflow/cc/saved_model/experimental/public/BUILD @@ -51,8 +51,32 @@ cc_library( deps = [ ":concrete_function", ":concrete_function_list", + ":signature_def_function", "//tensorflow/c/experimental/saved_model/public:saved_model_api", "//tensorflow/cc/experimental/base/public:runtime", "//tensorflow/cc/experimental/base/public:status", ], ) + +cc_library( + name = "signature_def_function", + hdrs = [ + "signature_def_function.h", + ], + deps = [ + ":signature_def_function_metadata", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/experimental/saved_model/public:signature_def_function", + "//tensorflow/cc/experimental/base/public:status", + ], +) + +cc_library( + name = "signature_def_function_metadata", + hdrs = [ + "signature_def_function_metadata.h", + ], + deps = [ + "//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata", + ], +) diff --git a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h index 04018bf2aab..c2bfb4dcf83 100644 --- a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h +++ b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/cc/experimental/base/public/status.h" #include "tensorflow/cc/saved_model/experimental/public/concrete_function.h" #include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h" +#include "tensorflow/cc/saved_model/experimental/public/signature_def_function.h" namespace tensorflow { namespace experimental { @@ -80,8 +81,8 @@ class SavedModelAPI { // If status is not OK, returns nullptr. Otherwise, returns a // tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer // is bound to SavedModelAPI it was loaded from. - ConcreteFunction* GetSignatureDefFunction(const std::string& function_path, - Status* status); + SignatureDefFunction* GetSignatureDefFunction( + const std::string& function_path, Status* status); // Lists all Conrete Functions available from the SavedModel. std::vector ListFunctions(); @@ -140,14 +141,14 @@ inline ConcreteFunction* SavedModelAPI::GetConcreteFunction( return ConcreteFunction::wrap(function); } -inline ConcreteFunction* SavedModelAPI::GetSignatureDefFunction( +inline SignatureDefFunction* SavedModelAPI::GetSignatureDefFunction( const std::string& function_path, Status* status) { - TF_ConcreteFunction* function = TF_GetSavedModelSignatureDefFunction( + TF_SignatureDefFunction* function = TF_GetSavedModelSignatureDefFunction( saved_model_.get(), function_path.c_str(), status->GetTFStatus()); if (!status->ok()) { return nullptr; } - return ConcreteFunction::wrap(function); + return SignatureDefFunction::wrap(function); } inline std::vector SavedModelAPI::ListFunctions() { diff --git a/tensorflow/cc/saved_model/experimental/public/signature_def_function.h b/tensorflow/cc/saved_model/experimental/public/signature_def_function.h new file mode 100644 index 00000000000..bc72d208e87 --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/signature_def_function.h @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_
+#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_
+
+#include
+
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
+#include "tensorflow/cc/experimental/base/public/status.h"
+#include "tensorflow/cc/saved_model/experimental/public/signature_def_function_metadata.h"
+
+namespace tensorflow {
+namespace experimental {
+namespace cc {
+
+// SignatureDefFunctions are functions that correspond to either:
+// "signatures" saved from the TF2 SavedModel APIs:
+// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/python/saved_model/save.py#L830-L854
+// or the "SignatureDefMap" saved from the TF1 SavedModel APIs:
+// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/python/saved_model/load_v1_in_v2_test.py#L170-L174
+// In both cases, a SignatureDef is serialized as a SignatureDef protobuf:
+// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/core/protobuf/meta_graph.proto#L260-L330
+// and represents a computation defined by a TF subgraph.
+// These Signatures were primarily designed to be interoperable with the legacy
+// TF 1 Session-based C++ SavedModelBundle loading APIs:
+// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/cc/saved_model/loader.h#L96-L108
+// SignatureDefFunctions have different semantics from regular TF2
+// ConcreteFunctions, and are mainly intended to provide a serving-friendly
+// transition point from the TF1 Session API.
+// First, SignatureDefFunctions have different calling conventions.
+// SignatureDefFunctions' inputs and outputs are constrained to **flattened
+// lists of TensorHandles only**. They do not support more exotic input/output
+// types (like optionals, generators, etc). Additionally, this flattening means
+// they will not preserve the exact interface of the original tf.function they
+// were traced from, as things like composite tensors decay into their
+// internal dense tensor representation.
+// Second, all inputs and outputs are "named", and these names are load bearing
+// (e.g. they are part of the interface of tensorflow_serving):
+// https://github.com/tensorflow/serving/blob/e0d247b2e4050713194b8fad0be24a0636df7209/tensorflow_serving/apis/predict.proto#L21
+// https://github.com/tensorflow/serving/blob/e0d247b2e4050713194b8fad0be24a0636df7209/tensorflow_serving/apis/predict.proto#L39
+// The name of each input/output is stored in the corresponding tf::Argument in
+// SignatureDefFunctionMetadata::arguments(). Users must ensure the order of
+// TensorHandles passed to the function matches the order of named arguments.
+// Similarly, the names of the outputs are stored in
+// SignatureDefFunctionMetadata::returns().
+class SignatureDefFunction final {
+ public:
+  // Returns FunctionMetadata associated with this SignatureDefFunction.
+  const SignatureDefFunctionMetadata* GetFunctionMetadata();
+
+ private:
+  friend class SavedModelAPI;
+  friend class ConcreteFunctionList;
+
+  // TODO(bmzhao): Consider adding a macro for wrapping/unwrapping
+  // when moving out of experimental.
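  // Editorial sketch (not part of the original patch), showing how a caller
  // might obtain and inspect one of these functions; "serving_default" is a
  // hypothetical signature key and `model` an already-loaded SavedModelAPI:
  //
  //   Status status;
  //   SignatureDefFunction* fn =
  //       model->GetSignatureDefFunction("serving_default", &status);
  //   if (status.ok()) {
  //     const SignatureDefFunctionMetadata* metadata =
  //         fn->GetFunctionMetadata();
  //   }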
+ static SignatureDefFunction* wrap(TF_SignatureDefFunction* p) { + return reinterpret_cast(p); + } + static TF_SignatureDefFunction* unwrap(SignatureDefFunction* p) { + return reinterpret_cast(p); + } +}; + +inline const SignatureDefFunctionMetadata* +SignatureDefFunction::GetFunctionMetadata() { + return SignatureDefFunctionMetadata::wrap( + TF_SignatureDefFunctionGetMetadata(unwrap(this))); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/signature_def_function_metadata.h b/tensorflow/cc/saved_model/experimental/public/signature_def_function_metadata.h new file mode 100644 index 00000000000..6cb01bf1a26 --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/signature_def_function_metadata.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ + +#include + +#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// SignatureDefFunctionMetadata stores additional information on each input +// and output's names, dtypes, and shape. +class SignatureDefFunctionMetadata final { + // TODO(bmzhao): Add getters here as necessary. + private: + friend class SignatureDefFunction; + static SignatureDefFunctionMetadata* wrap( + TF_SignatureDefFunctionMetadata* p) { + return reinterpret_cast(p); + } + static TF_SignatureDefFunctionMetadata* unwrap( + SignatureDefFunctionMetadata* p) { + return reinterpret_cast(p); + } +}; + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index f9c720a2ba2..ecefe7d0406 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/saver.pb.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" @@ -95,16 +96,6 @@ static Status ValidateSavedTensors(const GraphDef& graph_def) { return Status::OK(); } -Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, - const SessionOptions& session_options, - std::unique_ptr* session) { - Session* session_p = nullptr; - TF_RETURN_IF_ERROR(NewSession(session_options, &session_p)); - session->reset(session_p); - TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph_def.graph_def())); - return (*session)->Create(meta_graph_def.graph_def()); -} - Tensor CreateStringTensor(const string& value) { Tensor tensor(DT_STRING, TensorShape({})); tensor.scalar()() = value; @@ -228,22 +219,18 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, nullptr /* outputs */, &run_metadata, session); } -Status ReadSavedModelDebugInfoIfPresent( - const string& export_dir, - std::unique_ptr* debug_info_proto) { - LOG(INFO) << "Reading SavedModel debug info (if present) from: " - << export_dir; +} // namespace - const string debug_info_pb_path = - io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb"); - if (Env::Default()->FileExists(debug_info_pb_path).ok()) { - GraphDebugInfo debug_info; - TF_RETURN_IF_ERROR( - ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info)); - *debug_info_proto = - absl::make_unique(std::move(debug_info)); - } - return Status::OK(); +SavedModelBundleInterface::~SavedModelBundleInterface() {} + +Status LoadMetagraphIntoSession(const SessionOptions& session_options, + const MetaGraphDef& meta_graph, + std::unique_ptr* session) { + Session* session_p = nullptr; + TF_RETURN_IF_ERROR(NewSession(session_options, &session_p)); + session->reset(session_p); + TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph.graph_def())); + return (*session)->Create(meta_graph.graph_def()); } Status LoadSavedModelInternal(const SessionOptions& session_options, @@ -251,46 +238,17 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, const string& export_dir, const std::unordered_set& tags, SavedModelBundle* const bundle) { - const uint64 read_start_microseconds = Env::Default()->NowMicros(); TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags, &bundle->meta_graph_def)); TF_RETURN_IF_ERROR( ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info)); - TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession( - bundle->meta_graph_def, session_options, &bundle->session)); - - std::vector asset_file_defs; - TF_RETURN_IF_ERROR( - internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs)); - TF_RETURN_IF_ERROR( - RunRestore(run_options, export_dir, - bundle->meta_graph_def.saver_def().restore_op_name(), - bundle->meta_graph_def.saver_def().filename_tensor_name(), - asset_file_defs, bundle->session.get())); - // Record walltime spent in restoring graph from disk, but postpone metric - // increments until graph init finishes. 
- const uint64 restore_graph_walltime = - GetLatencyMicroseconds(read_start_microseconds); - - const uint64 graph_init_start_microseconds = Env::Default()->NowMicros(); - string init_op_name; - TF_RETURN_IF_ERROR( - internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name)); - TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def, - asset_file_defs, bundle->session.get(), - init_op_name)); - load_latency_by_stage->GetCell(export_dir, "restore_graph") - ->Add(restore_graph_walltime); - // Record wall time spent in init op. - load_latency_by_stage->GetCell(export_dir, "init_graph") - ->Add(GetLatencyMicroseconds(graph_init_start_microseconds)); + TF_RETURN_IF_ERROR(LoadMetagraphIntoSession( + session_options, bundle->meta_graph_def, &bundle->session)); + TF_RETURN_IF_ERROR(RestoreSession(run_options, bundle->meta_graph_def, + export_dir, &bundle->session)); return Status::OK(); } -} // namespace - -SavedModelBundleInterface::~SavedModelBundleInterface() {} - Status LoadSavedModel(const SessionOptions& session_options, const RunOptions& run_options, const string& export_dir, const std::unordered_set& tags, @@ -424,6 +382,35 @@ class LiteSessionWrapper : public Session { }; } // namespace +Status RestoreSession(const RunOptions& run_options, + const MetaGraphDef& meta_graph, const string& export_dir, + std::unique_ptr* session) { + const uint64 read_start_microseconds = Env::Default()->NowMicros(); + std::vector asset_file_defs; + TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs)); + TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir, + meta_graph.saver_def().restore_op_name(), + meta_graph.saver_def().filename_tensor_name(), + asset_file_defs, session->get())); + // Record walltime spent in restoring graph from disk, but postpone metric + // increments until graph init finishes. + const uint64 restore_graph_walltime = + GetLatencyMicroseconds(read_start_microseconds); + + const uint64 graph_init_start_microseconds = Env::Default()->NowMicros(); + string init_op_name; + TF_RETURN_IF_ERROR( + internal::GetInitOp(export_dir, meta_graph, &init_op_name)); + TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, meta_graph, + asset_file_defs, session->get(), init_op_name)); + load_latency_by_stage->GetCell(export_dir, "restore_graph") + ->Add(restore_graph_walltime); + // Record wall time spent in init op. + load_latency_by_stage->GetCell(export_dir, "init_graph") + ->Add(GetLatencyMicroseconds(graph_init_start_microseconds)); + return Status::OK(); +} + Status LoadSavedModel(const SessionOptions& session_options, const RunOptions& run_options, const string& export_dir, const std::unordered_set& tags, diff --git a/tensorflow/cc/saved_model/loader.h b/tensorflow/cc/saved_model/loader.h index 2b2e44bc619..5ef6070998e 100644 --- a/tensorflow/cc/saved_model/loader.h +++ b/tensorflow/cc/saved_model/loader.h @@ -96,6 +96,21 @@ class SavedModelBundleLite : public SavedModelBundleInterface { protobuf::Map signatures_; }; +// Restore variable and resources in the SavedModel export dir for the +// indicated metagraph. +// The recommended way to load a saved model is to call LoadSavedModel, +// which provides an already initialized Metagraph, Session, and DebugInfo. +Status RestoreSession(const RunOptions& run_options, + const MetaGraphDef& meta_graph, const string& export_dir, + std::unique_ptr* session); + +// Initialize a session which wraps this metagraph. 
+// The recommended way to load a saved model is to call LoadSavedModel, +// which provides an already initialized Metagraph, Session, and DebugInfo. +Status LoadMetagraphIntoSession(const SessionOptions& session_options, + const MetaGraphDef& meta_graph, + std::unique_ptr* session); + /// Loads a SavedModel from the specified export directory. The MetaGraphDef /// to be loaded is identified by the supplied tags, corresponding exactly to /// the set of tags used at SavedModel build time. Stores a SavedModel bundle in diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc index d6d99229372..c1d4736f6b9 100644 --- a/tensorflow/cc/saved_model/reader.cc +++ b/tensorflow/cc/saved_model/reader.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -86,4 +87,22 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir, return Status::OK(); } +Status ReadSavedModelDebugInfoIfPresent( + const string& export_dir, + std::unique_ptr* debug_info_proto) { + LOG(INFO) << "Reading SavedModel debug info (if present) from: " + << export_dir; + + const string debug_info_pb_path = + io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb"); + if (Env::Default()->FileExists(debug_info_pb_path).ok()) { + GraphDebugInfo debug_info; + TF_RETURN_IF_ERROR( + ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info)); + *debug_info_proto = + absl::make_unique(std::move(debug_info)); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/reader.h b/tensorflow/cc/saved_model/reader.h index 5815108df2a..602f6cb21c1 100644 --- a/tensorflow/cc/saved_model/reader.h +++ b/tensorflow/cc/saved_model/reader.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" namespace tensorflow { @@ -34,6 +35,11 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir, const std::unordered_set& tags, MetaGraphDef* const meta_graph_def); +// Store debug info from the SavedModel export dir. +Status ReadSavedModelDebugInfoIfPresent( + const string& export_dir, + std::unique_ptr* debug_info_proto); + } // namespace tensorflow #endif // TENSORFLOW_CC_SAVED_MODEL_READER_H_ diff --git a/tensorflow/cc/saved_model/reader_test.cc b/tensorflow/cc/saved_model/reader_test.cc index bc630bcaede..b5e8b67a123 100644 --- a/tensorflow/cc/saved_model/reader_test.cc +++ b/tensorflow/cc/saved_model/reader_test.cc @@ -106,5 +106,11 @@ TEST_F(ReaderTest, InvalidExportPath) { EXPECT_FALSE(st.ok()); } +TEST_F(ReaderTest, ReadSavedModelDebugInfoIfPresent) { + const string export_dir = GetDataDependencyFilepath(TestDataSharded()); + std::unique_ptr debug_info_proto; + TF_ASSERT_OK(ReadSavedModelDebugInfoIfPresent(export_dir, &debug_info_proto)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/saved_model_bundle_test.cc b/tensorflow/cc/saved_model/saved_model_bundle_test.cc index d6c375c7448..3f258745fa4 100644 --- a/tensorflow/cc/saved_model/saved_model_bundle_test.cc +++ b/tensorflow/cc/saved_model/saved_model_bundle_test.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/cc/saved_model/loader.h" - #include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/cc/saved_model/signature_constants.h" #include "tensorflow/cc/saved_model/tag_constants.h" #include "tensorflow/core/example/example.pb.h" @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" namespace tensorflow { namespace { @@ -131,6 +132,43 @@ TEST_F(LoaderTest, TagMatch) { CheckSavedModelBundle(export_dir, bundle); } +TEST_F(LoaderTest, ReadMetaGraphFromSavedModel) { + SavedModelBundle bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + MetaGraphDef actual_metagraph; + TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, + &actual_metagraph)); + EXPECT_EQ(actual_metagraph.DebugString(), + bundle.meta_graph_def.DebugString()); +} + +TEST_F(LoaderTest, RestoreSession) { + SavedModelBundle bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + + SavedModelBundle actual_bundle; + const std::unordered_set tags = {kSavedModelTagServe}; + TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, tags, + &actual_bundle.meta_graph_def)); + TF_ASSERT_OK(LoadMetagraphIntoSession( + session_options, actual_bundle.meta_graph_def, &actual_bundle.session)); + TF_ASSERT_OK(RestoreSession(run_options, actual_bundle.meta_graph_def, + export_dir, &actual_bundle.session)); + CheckSavedModelBundle(export_dir, actual_bundle); +} + TEST_F(LoaderTest, NoTagMatch) { SavedModelBundle bundle; RunOptions run_options; @@ -270,6 +308,9 @@ TEST_F(LoaderTest, NegativeShapeDimension) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &bundle); EXPECT_FALSE(st.ok()); + EXPECT_NE( + st.error_message().find("initializes from a tensor with -1 elements"), + std::string::npos); } TEST_F(LoaderTest, ConstNoValue) { @@ -282,6 +323,9 @@ TEST_F(LoaderTest, ConstNoValue) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &bundle); EXPECT_FALSE(st.ok()); + EXPECT_NE( + st.error_message().find("constant tensor but no value has been provided"), + std::string::npos); } } // namespace diff --git a/third_party/sycl/crosstool/BUILD b/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/assets/empty similarity index 100% rename from third_party/sycl/crosstool/BUILD rename to tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/assets/empty diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value b/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/saved_model.pb similarity index 100% rename from tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value rename to 
tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/saved_model.pb diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..3fd3ba2223d Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/variables/variables.index b/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/variables/variables.index new file mode 100644 index 00000000000..7357e8d57ed Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value/variables/variables.index differ diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/assets/empty b/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/assets/empty new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape b/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/saved_model.pb similarity index 100% rename from tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape rename to tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/saved_model.pb diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..3fd3ba2223d Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/variables/variables.index b/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/variables/variables.index new file mode 100644 index 00000000000..7357e8d57ed Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape/variables/variables.index differ diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index ecbb1a5d200..82375577610 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -15,6 +15,7 @@ package_group( "//tensorflow/compiler/tf2xla:internal", ], packages = [ + "//tensorflow/c/...", "//tensorflow/compiler/tests/...", "//tensorflow/python/...", ], @@ -128,22 +129,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "xla_interpreter_device", - srcs = ["xla_interpreter_device.cc"], - visibility = [":friends"], - deps = [ - ":jit_compilation_passes", - ":xla_device", - "//tensorflow/compiler/jit/kernels:xla_ops", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep - "@com_google_absl//absl/memory", - ], - alwayslink = 1, -) - cc_library( name = "xla_tensor", srcs = ["xla_tensor.cc"], @@ -211,6 +196,7 @@ XLA_DEVICE_DEPS = [ "//tensorflow/core/kernels/data:optional_ops", "//tensorflow/core/kernels/data:prefetch_dataset_op", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/stream_executor:tf_allocator_adapter", "//tensorflow/stream_executor/platform", ] @@ -221,16 +207,19 @@ cc_library( "xla_device.cc", 
"xla_device_context.cc", "xla_device_ops.cc", + "xla_ops_on_regular_devices.cc", + "xla_platform_info.cc", ], hdrs = [ "xla_compile_on_demand_op.h", "xla_device.h", "xla_device_context.h", "xla_device_ops.h", + "xla_platform_info.h", ], # Public visibility is needed for external TF/XLA backends. visibility = ["//visibility:public"], - deps = XLA_DEVICE_DEPS, + deps = XLA_DEVICE_DEPS + [":xla_compilation_cache"], ) cc_library( @@ -341,8 +330,10 @@ cc_library( srcs = ["xla_compilation_cache.cc"], hdrs = ["xla_compilation_cache.h"], deps = [ + ":flags", ":xla_activity_listener", ":xla_activity_proto_cc", + "//tensorflow/compiler/mlir:array_container_utils", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -373,8 +364,11 @@ tf_cc_test( "xla_compilation_cache_test.cc", ], deps = [ + ":flags", ":xla_compilation_cache", + ":xla_cpu_jit", "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/xla/client:client_library", "//tensorflow/core:test", "//tensorflow/core:test_main", ], @@ -394,20 +388,70 @@ cc_library( alwayslink = 1, ) -# Linked by tensorflow core, without registration of jit compilation passes -# which is not necessary to create and run a XlaLocalLaunchBase kernel. -# Linking jit compilation passes could cause programs stuck right now (b/140069592). cc_library( - name = "xla_kernel_creator_util", - srcs = [ - "xla_kernel_creator_util.cc", + name = "get_compiler_ir", + srcs = ["get_compiler_ir.cc"], + hdrs = ["get_compiler_ir.h"], + visibility = [ + ":internal", + "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", + "//tensorflow/core/common_runtime/eager:__pkg__", + ], + deps = [ + ":common", + ":compilability_check_util", + ":flags", + ":xla_device_no_jit_rewrite_registration", + ":xla_launch_util", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/common_runtime:core_cpu_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + ], + alwayslink = 1, +) + +# Header-only version of "flags" library, for linking from the shared object +# without ODR violations. 
+cc_library( + name = "get_compiler_ir_hdrs_only", + hdrs = ["get_compiler_ir.h"], + visibility = [ + ":internal", + "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", + "//tensorflow/core/common_runtime/eager:__pkg__", + ], + deps = [ + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "xla_kernel_creator", + srcs = [ + "xla_kernel_creator.cc", + "xla_kernel_creator.h", + ], + visibility = [ + ":internal", + "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", + "//tensorflow/core/common_runtime/eager:__pkg__", ], - hdrs = ["xla_kernel_creator_util.h"], - visibility = ["//tensorflow/core/common_runtime/eager:__pkg__"], deps = [ ":common", ":compilability_check_util", ":compilation_passes", + ":flags", + ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_op_registry", @@ -422,25 +466,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "xla_kernel_creator", - srcs = [ - "xla_kernel_creator.cc", - "xla_kernel_creator.h", - ], - deps = [ - ":compilability_check_util", - ":flags", - ":jit_compilation_passes", - ":xla_kernel_creator_util", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], - alwayslink = 1, -) - tf_cc_test( name = "xla_kernel_creator_test", srcs = [ @@ -632,7 +657,6 @@ cc_library( ":flags", ":resource_operation_safety_analysis", ":shape_inference_helpers", - ":union_find", ":xla_activity_listener", ":xla_cluster_util", "//tensorflow/cc:cc_ops", @@ -651,8 +675,8 @@ cc_library( "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", - "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -729,11 +753,6 @@ tf_cc_test( ], ) -cc_library( - name = "union_find", - hdrs = ["union_find.h"], -) - tf_cc_test( name = "deadness_analysis_test", size = "small", @@ -828,6 +847,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:test", + "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -914,7 +934,6 @@ cc_library( ":device_util", ":flags", ":resource_operation_safety_analysis", - ":union_find", ":xla_activity_listener", ":xla_activity_proto_cc", ":xla_cluster_util", @@ -923,6 +942,7 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -946,6 +966,7 @@ tf_cc_test( ":xla_cpu_jit", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:scope", "//tensorflow/compiler/tf2xla:test_util", @@ -972,6 +993,7 @@ tf_cc_test( ":xla_cpu_jit", "//tensorflow/cc:cc_ops", "//tensorflow/cc:ops", + "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu", 
"//tensorflow/core:direct_session_internal", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index 8463c788496..160ea83585d 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -130,17 +130,6 @@ FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { return fdef_lib; } -FunctionDefLibrary CreateFunctionDefLibWithInt32Input(const string& name) { - FunctionDefLibrary fdef_lib; - FunctionDef func = FunctionDefHelper::Create( - /*function_name=*/name, /*in_def=*/{"in: int32"}, - /*out_def=*/{"out: int32"}, - /*attr_def=*/{}, /*node_def=*/{{{"out"}, "Identity", {"in"}}}, - /*ret_def=*/{{"out", "out:output:0"}}); - *fdef_lib.add_function() = std::move(func); - return fdef_lib; -} - TEST_F(BuildXlaOpsTest, ControlDepsPreserved) { const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; Scope root = Scope::NewRootScope().WithDevice(kXlaDeviceName).ExitOnError(); @@ -269,6 +258,17 @@ TEST_F(BuildXlaOpsTest, NoExtraMergeForEdgeToSink) { } #ifdef GOOGLE_CUDA +FunctionDefLibrary CreateFunctionDefLibWithInt32Input(const string& name) { + FunctionDefLibrary fdef_lib; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/name, /*in_def=*/{"in: int32"}, + /*out_def=*/{"out: int32"}, + /*attr_def=*/{}, /*node_def=*/{{{"out"}, "Identity", {"in"}}}, + /*ret_def=*/{{"out", "out:output:0"}}); + *fdef_lib.add_function() = std::move(func); + return fdef_lib; +} + // This tests a rewrite that only makes sense and is active in a CUDA-enabled // build. Specifically we check that we insert an IdentityN op to avoid extra // device-to-host copies. diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 6d4bc51f1b2..20efbe248d7 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" @@ -44,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" @@ -84,6 +84,43 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name, return Status::OK(); } +// Utility which searches for values in a sorted list by scanning over it once. +// No matter how many times ScanForValue is called, the list is scanned at most +// once. However, if a call to ScanForValue skips over a value, that value is +// not revisited in future calls to ScanForValue, so callers must take +// care to order their calls. +// +// Useful for merging multiple sorted lists in O(n) time. +class SinglePassSearch { + public: + // Creates a SinglePassSearch object that can be used to search in `values`. 
+ // Does not take ownership of `values`. `values` must outlive this. + // `values` must be sorted. + explicit SinglePassSearch(absl::Span<int const> values) + : current_index_(0), values_(values) {} + + // Scans forward in the vector looking for "value", updating the internal + // position in to the vector. + // Returns true iff the vector contains the given value at or after current + // position. + // Not thread-safe. + bool ScanForValue(int value) { + while (current_index_ < values_.size() && + values_[current_index_] <= value) { + if (values_[current_index_] == value) { + current_index_++; + return true; + } + current_index_++; + } + return false; + } + + private: + int current_index_; + const absl::Span<int const> values_; +}; + } // anonymous namespace RecursiveCompilabilityChecker::UncompilableNodesMap @@ -518,23 +555,23 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter( } } +// Returns `true` iff node has a given `attr` set to `true`. Returns `false` +// both for the missing attr, and the attr set to `false`. +static bool HasBoolAttr(const NodeDef& node, const char* attr) { + const auto& it = node.attr().find(attr); + return it != node.attr().end() && it->second.b(); +} + bool CanCreateXlaKernel(const NodeDef& node_def) { - // If kXlaMustCompileAttr is set on the node_def, use its value. - const auto& it = node_def.attr().find(kXlaMustCompileAttr); - return it != node_def.attr().end() && it->second.b(); + return HasBoolAttr(node_def, kXlaMustCompileAttr); } Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, - const NodeDef& node_def, + const NameAttrList& function, const FunctionBody** fbody, std::vector<int>* constant_arg_indices, std::vector<int>* resource_arg_indices) { FunctionLibraryRuntime::Handle handle; - // If node_def is not instantiable, e.g., the function does not exist, - // simply bail out. - NameAttrList function; - TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); - TF_RETURN_IF_ERROR( flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle)); *fbody = flr->GetFunctionBody(handle); @@ -564,4 +601,96 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, return Status::OK(); } +tensorflow::MemoryTypeVector GetInputMemoryTypes( + const tensorflow::FunctionBody* fbody, + absl::Span<int const> constant_arg_indices, + absl::Span<int const> resource_arg_indices) { + // Set input and output memory types. + tensorflow::MemoryTypeVector input_memory_types(fbody->arg_types.size(), + tensorflow::DEVICE_MEMORY); + // These indices are used only for optimization purposes. They allow us + // to loop over constant_arg_indices and resource_arg_indices only once + // while iterating over all the function arguments checking if it is a + // resource or a constant. + // The reason we optimized this code is because functions can have a lot of + // captured arguments. For example, the backward pass of ResNet50 takes in all + // 214 variables and a similar number of activations. + SinglePassSearch constants_search(constant_arg_indices); + SinglePassSearch resources_search(resource_arg_indices); + for (size_t i = 0; i < fbody->arg_types.size(); ++i) { + if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) { + // Compile-time constants and resource handles are expected to be in + // host memory.
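A minimal sketch of how SinglePassSearch is intended to be driven (the index values below are made up for illustration): queries must arrive in non-decreasing order, so the sorted list is walked at most once no matter how many times ScanForValue is called.

  // Sorted indices of, say, compile-time constant arguments (illustrative).
  std::vector<int> constant_arg_indices = {1, 4};
  SinglePassSearch constants_search(constant_arg_indices);
  for (int i = 0; i < 6; ++i) {
    // True exactly for i == 1 and i == 4; the internal cursor only advances.
    bool is_constant_arg = constants_search.ScanForValue(i);
    (void)is_constant_arg;
  }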
+ input_memory_types[i] = tensorflow::HOST_MEMORY; + } + } + return input_memory_types; +} + +tensorflow::MemoryTypeVector GetOutputMemoryTypes( + const tensorflow::FunctionBody* fbody) { + tensorflow::MemoryTypeVector output_memory_types(fbody->ret_types.size(), + tensorflow::DEVICE_MEMORY); + for (size_t i = 0; i < fbody->ret_types.size(); ++i) { + if (fbody->ret_types[i] == tensorflow::DT_RESOURCE) { + output_memory_types[i] = tensorflow::HOST_MEMORY; + } + } + return output_memory_types; +} + +static auto const ops_triggering_xla_compilation = + new absl::flat_hash_set{"XlaBroadcastHelper", + "XlaConv", + "XlaDequantize", + "XlaDot", + "XlaDynamicSlice", + "XlaDynamicUpdateSlice", + "XlaEinsum", + "XlaGather", + "XlaIf", + "XlaKeyValueSort", + "XlaPad", + "XlaRecv", + "XlaReduce", + "XlaReduceWindow", + "XlaReplicaId", + "XlaScatter", + "XlaSelectAndScatter", + "XlaSelfAdjointEig", + "XlaSend", + "XlaSharding", + "XlaSort", + "XlaSpmdFullToShardShape", + "XlaSpmdShardToFullShape", + "XlaSvd", + "XlaWhile"}; + +static bool NodeCanTriggerXlaCompilation(const NodeDef& node) { + return node.attr().find(kXlaClusterIdAttr) != node.attr().end() || + HasBoolAttr(node, kXlaMustCompileAttr) || + HasBoolAttr(node, kXlaCompileAttr) || + HasBoolAttr(node, kXlaScopeAttr) || + HasBoolAttr(node, kXlaInternalScopeAttr) || + ops_triggering_xla_compilation->count(node.op()); +} + +bool CanTriggerXlaCompilation(const GraphDef& graph) { + for (const FunctionDef& function : graph.library().function()) { + for (const NodeDef& node : function.node_def()) { + if (NodeCanTriggerXlaCompilation(node)) { + return true; + } + } + } + + for (const NodeDef& node : graph.node()) { + if (NodeCanTriggerXlaCompilation(node)) { + return true; + } + } + + return false; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 3b20784cc29..3c1378bf764 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -26,11 +26,11 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" @@ -126,9 +126,10 @@ class RecursiveCompilabilityChecker { bool allow_inaccurate_ops = false; }; - RecursiveCompilabilityChecker(const OperationFilter* op_filter, - const DeviceType* jit_device_type) - : op_filter_(*op_filter), jit_device_type_(*jit_device_type) {} + RecursiveCompilabilityChecker(OperationFilter op_filter, + DeviceType jit_device_type) + : op_filter_(std::move(op_filter)), + jit_device_type_(std::move(jit_device_type)) {} using UncompilableNodesMap = std::map* constant_arg_indices, std::vector* resource_arg_indices); @@ -282,6 +282,44 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, // set. bool CanCreateXlaKernel(const NodeDef& node_def); +// Returns memory types for the input. 
+// `constant_arg_indices` and `resource_arg_indices` are sorted arrays of +// indices corresponding to constant and resource arguments respectively. +// +// One might wonder, about the case where a compile-time constant argument +// (which must be in host memory) is also used as an input into an op, +// e.g. `Add`, that expects its inputs in device memory. Here is how it +// works now. +// First, what do we mean by "op expects an input in XYZ memory"? +// There are two types of "ops" here: the tf2xla kernel and the HLO +// computation it builds. The tf2xla kernel needs to retrieve the actual +// numeric value of the compile-time constant tensors, so it really expects +// them to be on in host memory. However, for other inputs, it refers to them +// using xla::ComputationDataHandle, which is just a symbolic handle that +// xla::ComputationBuilder assigns. How does this handle gets assigned for +// constant arguments? Even constant arguments get an _Arg node in the graph +// instantiated for Function compilation. The tf2xla kernel for constant _Arg +// nodes takes the constant value, converts it to XlaLiteral, and feeds it +// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This +// constant XlaLiteral is included in the HLO graph, and subsequently, in +// the actual executable, which is copied to the device before being +// executed. Thus, when this executable runs, the constant is available in +// device memory. +tensorflow::MemoryTypeVector GetInputMemoryTypes( + const tensorflow::FunctionBody* fbody, + absl::Span constant_arg_indices, + absl::Span resource_arg_indices); + +// Returns output memory types. +// +// XlaLaunch kernel keeps all outputs (including constants, which it copies), +// in device memory except for resources. +tensorflow::MemoryTypeVector GetOutputMemoryTypes( + const tensorflow::FunctionBody* fbody); + +// Check whether graph can trigger XLA compilation. +bool CanTriggerXlaCompilation(const GraphDef& graph); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_ diff --git a/tensorflow/compiler/jit/compilability_check_util_test.cc b/tensorflow/compiler/jit/compilability_check_util_test.cc index 3ea38e69ad9..3851c66ba1a 100644 --- a/tensorflow/compiler/jit/compilability_check_util_test.cc +++ b/tensorflow/compiler/jit/compilability_check_util_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -75,8 +76,8 @@ class CompilabilityCheckUtilTest : public ::testing::Test { op_filter_.allow_inaccurate_ops = false; op_filter_.allow_slow_ops = false; - checker_ = absl::make_unique(&op_filter_, - &device_type_); + checker_ = absl::make_unique(op_filter_, + device_type_); } FunctionLibraryRuntime* GetFunctionLibraryRuntime() { @@ -354,5 +355,110 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) { "unsupported op")); } +TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Scope root = Scope::NewRootScope().ExitOnError(); + FunctionDefLibrary library; + + FunctionDef identity_func = FunctionDefHelper::Create( + "IdentityFunc", + /*in_def=*/{"x:float"}, + /*out_def=*/{"res:float"}, + /*attr_def=*/{}, + /*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}}, + /*ret_def*/ {{"res", "t0:output"}}); + + *library.add_function() = identity_func; + + Output in = ops::Placeholder(root, DT_FLOAT); + NameAttrList b_name_attr; + b_name_attr.set_name("IdentityFunc"); + ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT}, + b_name_attr); + + GraphDef graph_def; + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library)); + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + EXPECT_FALSE(CanTriggerXlaCompilation(graph_def)); +} + +TEST_F(CompilabilityCheckUtilTest, TestXlaOpsCanTriggerXlaCompilation) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Scope root = Scope::NewRootScope().ExitOnError(); + FunctionDefLibrary library; + + FunctionDef sort_func = FunctionDefHelper::Create( + "SortFunc", + /*in_def=*/{"x:float"}, + /*out_def=*/{"res:float"}, + /*attr_def=*/{}, + /*node_def=*/{{{"t0"}, "XlaSort", {"x"}, {{"T", DT_FLOAT}}}}, + /*ret_def*/ {{"res", "t0:output"}}); + + *library.add_function() = sort_func; + + Output in = ops::Placeholder(root, DT_FLOAT); + NameAttrList b_name_attr; + b_name_attr.set_name("SortFunc"); + ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT}, + b_name_attr); + + GraphDef graph_def; + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library)); + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + EXPECT_TRUE(CanTriggerXlaCompilation(graph_def)); +} + +TEST_F(CompilabilityCheckUtilTest, TestCanTriggerXlaCompilation) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Scope root = Scope::NewRootScope().ExitOnError(); + FunctionDefLibrary library; + + AttrValue true_attribute; + true_attribute.set_b(true); + + FunctionDef identity_func = FunctionDefHelper::Create( + "IdentityFunc", + /*in_def=*/{"x:float"}, + /*out_def=*/{"res:float"}, + /*attr_def=*/{}, + /*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}}, + /*ret_def*/ {{"res", "t0:output"}}); + + (*identity_func.mutable_attr())[kXlaMustCompileAttr] = true_attribute; + + FunctionDef call_identity = FunctionDefHelper::Create( + "CallIdentity", + /*in_def=*/{"x:float"}, + /*out_def=*/{"z:float"}, /*attr_def=*/{}, + /*node_def=*/ + {{{"func_call"}, + "PartitionedCall", + {"x"}, + {{"Tin", DataTypeSlice({DT_FLOAT})}, + {"Tout", DataTypeSlice({DT_FLOAT})}, + {"f", + FunctionDefHelper::FunctionRef("IdentityRef", {{"T", DT_FLOAT}})}, + {kXlaMustCompileAttr, true}}}}, + 
/*ret_def=*/{{"z", "func_call:output:0"}}); + + *library.add_function() = identity_func; + *library.add_function() = call_identity; + + Output in = ops::Placeholder(root, DT_FLOAT); + NameAttrList b_name_attr; + b_name_attr.set_name("CallIdentity"); + ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT}, + b_name_attr); + + GraphDef graph_def; + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library)); + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + EXPECT_TRUE(CanTriggerXlaCompilation(graph_def)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/defs.cc b/tensorflow/compiler/jit/defs.cc index 4bea71e8fc1..84e1e36bcf6 100644 --- a/tensorflow/compiler/jit/defs.cc +++ b/tensorflow/compiler/jit/defs.cc @@ -28,4 +28,6 @@ const char* const kXlaScopeAttr = "_XlaScope"; // only when auto_jit is ON. const char* const kXlaInternalScopeAttr = "_XlaInternalScope"; +const char* const kXlaClusterIdAttr = "_xla_compile_id"; + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/defs.h b/tensorflow/compiler/jit/defs.h index 9eb4c2ca2e8..fa983db8df8 100644 --- a/tensorflow/compiler/jit/defs.h +++ b/tensorflow/compiler/jit/defs.h @@ -35,6 +35,9 @@ extern const char* const kXlaCompileAttr; // "_XlaCompile" extern const char* const kXlaScopeAttr; // "_XlaScope" extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope" +// The id of the compiled cluster. +extern const char* const kXlaClusterIdAttr; // "_xla_compile_id" + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_DEFS_H_ diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index ed25baa62ff..4a5c79c02d9 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -34,9 +35,6 @@ limitations under the License. namespace tensorflow { -const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr = - "_xla_compile_id"; - namespace { const char* const kXlaClusterOutput = "XlaClusterOutput"; @@ -45,10 +43,7 @@ bool IsCpuGpuCompile(const Graph* graph) { for (Node* n : graph->nodes()) { string name; // Only consider nodes being compiled. - if (!GetNodeAttr(n->attrs(), - EncapsulateXlaComputationsPass::kXlaClusterAttr, &name) - .ok()) - continue; + if (!GetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name).ok()) continue; // Early return for any node with a device that is not a CPU or GPU. DeviceNameUtils::ParsedName parsed; if (DeviceNameUtils::ParseFullName(n->requested_device(), &parsed)) { @@ -180,8 +175,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, retvals[i]->AddAttr("index", i); } - AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(), - call_def); + AddNodeAttr(kXlaClusterIdAttr, call_def->name(), call_def); AddNodeAttr("_variable_start_index", variable_start_index, call_def); // Uniquify the function name. @@ -216,8 +210,8 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, // O(n) pass over the edges. 
for (const Edge* e : (*graph)->edges()) { if (!e->IsControlEdge() && - e->src()->attrs().Find(kXlaClusterAttr) != nullptr && - e->dst()->attrs().Find(kXlaClusterAttr) == nullptr && + e->src()->attrs().Find(kXlaClusterIdAttr) != nullptr && + e->dst()->attrs().Find(kXlaClusterIdAttr) == nullptr && e->dst()->type_string() != kXlaClusterOutput) { return errors::InvalidArgument( "Undeclared output of XLA computation. Some common causes of this " @@ -232,9 +226,9 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, auto output = absl::make_unique((*graph)->op_registry()); TF_RETURN_WITH_CONTEXT_IF_ERROR( - EncapsulateSubgraphsInFunctions(kXlaClusterAttr, **graph, RewriteSubgraph, - /*reuse_existing_functions=*/true, - &output, flib_def), + EncapsulateSubgraphsInFunctions( + kXlaClusterIdAttr, **graph, RewriteSubgraph, + /*reuse_existing_functions=*/true, &output, flib_def), "EncapsulateXlaComputationsPass failed"); graph->swap(output); return Status::OK(); @@ -246,7 +240,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, // while iterating. std::vector launch_nodes; for (Node* n : graph->nodes()) { - const string& name = GetNodeAttrString(n->attrs(), kXlaClusterAttr); + const string& name = GetNodeAttrString(n->attrs(), kXlaClusterIdAttr); if (!name.empty()) { launch_nodes.push_back(n); } diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h index 3057e4c7469..9931b23fa41 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h @@ -34,8 +34,6 @@ namespace tensorflow { // XlaLaunch operators. class EncapsulateXlaComputationsPass : public GraphOptimizationPass { public: - static const char* const kXlaClusterAttr; // _xla_compile_id - Status Run(const GraphOptimizationPassOptions& options) override; // The following methods are public only for unit tests. diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc index cc177036591..61c9a3ff9c0 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" #include "tensorflow/compiler/tf2xla/test_util.h" @@ -46,19 +47,18 @@ static std::unique_ptr MakeOuterGraph( auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); NodeDef def; - TF_CHECK_OK( - NodeDefBuilder("launch0", function, &flib_def) - .Input(a.node()->name(), 0, DT_INT32) - .Input(b.node()->name(), 0, DT_FLOAT) - .Input(c.node()->name(), 0, DT_INT32) - .Input(d.node()->name(), 0, DT_FLOAT) - .Input(u.node()->name(), 0, DT_RESOURCE) - .Input(v.node()->name(), 0, DT_RESOURCE) - .Input(w.node()->name(), 0, DT_RESOURCE) - .Device("/gpu:0") - .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0") - .Attr("_variable_start_index", 4) - .Finalize(&def)); + TF_CHECK_OK(NodeDefBuilder("launch0", function, &flib_def) + .Input(a.node()->name(), 0, DT_INT32) + .Input(b.node()->name(), 0, DT_FLOAT) + .Input(c.node()->name(), 0, DT_INT32) + .Input(d.node()->name(), 0, DT_FLOAT) + .Input(u.node()->name(), 0, DT_RESOURCE) + .Input(v.node()->name(), 0, DT_RESOURCE) + .Input(w.node()->name(), 0, DT_RESOURCE) + .Device("/gpu:0") + .Attr(kXlaClusterIdAttr, "launch0") + .Attr("_variable_start_index", 4) + .Finalize(&def)); Status status; Node* launch = scope.graph()->AddNode(def, &status); @@ -107,7 +107,7 @@ static std::unique_ptr MakeBodyGraph() { auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6); auto add_attrs = [](Node* node) { - node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->AddAttr(kXlaClusterIdAttr, "launch0"); node->set_requested_device("/gpu:0"); }; @@ -155,8 +155,7 @@ TEST(EncapsulateXlaComputations, DeterministicEncapsulate) { : ops::Add(scope.WithOpName("E"), a1, a0); auto add_attrs = [](Node* node) { - node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, - "launch0"); + node->AddAttr(kXlaClusterIdAttr, "launch0"); }; add_attrs(e.node()); @@ -216,7 +215,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) { auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); auto add_attrs = [](Node* node) { - node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->AddAttr(kXlaClusterIdAttr, "launch0"); node->set_requested_device("/gpu:0"); }; diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index d1301a8c40f..ee7daf092da 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -159,7 +159,7 @@ void AllocateAndParseFlags() { device_flags = new XlaDeviceFlags; device_flags->tf_xla_compile_on_demand = false; - device_flags->tf_xla_enable_xla_devices = true; + device_flags->tf_xla_enable_xla_devices = false; ops_flags = new XlaOpsCommonFlags; ops_flags->tf_xla_always_defer_compilation = false; @@ -268,10 +268,10 @@ void AppendMarkForCompilationPassFlags(std::vector* flag_list) { AppendMarkForCompilationPassFlagsInternal(flag_list); } -static bool xla_is_enabled = false; +static std::atomic xla_compilation_disabled(false); -void SetXlaIsEnabled() { xla_is_enabled = true; } +void DisableXlaCompilation() { xla_compilation_disabled = true; } -bool IsXlaEnabled() { return xla_is_enabled; } +bool FailOnXlaCompilation() { return xla_compilation_disabled; } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/flags.h 
b/tensorflow/compiler/jit/flags.h index 89e20d9f8ea..5612b3b5864 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -162,13 +162,12 @@ MlirCommonFlags* GetMlirCommonFlags(); void AppendMarkForCompilationPassFlags( std::vector* flag_list); -// Makes all future calls to `IsXlaEnabled()` return `true`. -// -// Should only be called when XLA is linked in. -void SetXlaIsEnabled(); +// Disables XLA compilation, forces it to return an error message instead. Can +// be used by a server to ensure that JIT compilation is opt-in. +void DisableXlaCompilation(); -// Returns whether XLA is enabled. -bool IsXlaEnabled(); +// Returns `false` unless `DisableXlaCompilation` was called. +bool FailOnXlaCompilation(); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc b/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc index 3ba32f07506..3692d1f3aba 100644 --- a/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc +++ b/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc @@ -38,10 +38,12 @@ Status ForceXlaConstantsOnHostPass::Run( std::vector constant_arg_indices; std::vector resource_arg_indices; + NameAttrList function; + TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node->def(), &function)); + // Force all constants to be on the host memory. TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( - flr, node->def(), &fbody, &constant_arg_indices, - &resource_arg_indices)); + flr, function, &fbody, &constant_arg_indices, &resource_arg_indices)); VLOG(3) << "Found constant arg indices: " << absl::StrJoin(constant_arg_indices, ", "); diff --git a/tensorflow/compiler/jit/get_compiler_ir.cc b/tensorflow/compiler/jit/get_compiler_ir.cc new file mode 100644 index 00000000000..7c6a7583c3a --- /dev/null +++ b/tensorflow/compiler/jit/get_compiler_ir.cc @@ -0,0 +1,114 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/get_compiler_ir.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/jit/compilability_check_util.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { + +xla::StatusOr<std::string> GetCompilerIr( + IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, + absl::string_view func_name, Device* dev, + absl::Span<const Tensor* const> inputs) { + NameAttrList function; + function.set_name(std::string{func_name}); + + FunctionLibraryRuntime* flr = pflr->GetFLR(dev->name()); + ResourceMgr* rmgr = dev->resource_manager(); + + const FunctionBody* fbody = nullptr; + std::vector<int> constant_arg_indices; + std::vector<int> resource_arg_indices; + TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( + flr, function, &fbody, &constant_arg_indices, &resource_arg_indices)); + + MemoryTypeVector input_memory_types = + GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices); + MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody); + + std::vector<VariableInfo> variable_infos; + TF_RETURN_IF_ERROR(GetVariableInfosFromInputs( + rmgr, dev, inputs, resource_arg_indices, &variable_infos)); + TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos))); + + XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev); + + XlaCompilationCache* cache; + TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaCompilationCache>( + rmgr->default_container(), "xla_cache", &cache, + [&](XlaCompilationCache** cache_write_into) { + return BuildXlaCompilationCache(dev, platform_info, cache_write_into); + })); + core::ScopedUnref cache_ref(cache); + + absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter; + + XlaCompiler::Options options = + GenerateCompilerOptions(*cache, *flr, dev, + /*stream=*/nullptr, platform_info, + /*has_ref_vars=*/false, &tf_allocator_adapter); + + XlaCompiler::CompileOptions compile_options; + compile_options.always_return_tuple = false; + compile_options.alias_resource_update = true; + + XlaCompiler compiler(options); + + xla::StatusOr<std::vector<XlaCompiler::Argument>> args = + XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_arg_indices, inputs, variable_infos); + TF_RETURN_IF_ERROR(args.status()); + + switch (stage) { + case IrExportStage::HLO: { + XlaCompiler::CompilationResult result; + TF_RETURN_IF_ERROR( + compiler.CompileFunction(compile_options, function, *args, &result)); + + TF_ASSIGN_OR_RETURN(xla::ProgramShape program_shape, + result.computation->GetProgramShape()); + xla::HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN( + std::unique_ptr<xla::HloModule> new_module, + xla::HloModule::CreateFromProto(result.computation->proto(), config)); + + return new_module->ToString(); + } + case IrExportStage::OPTIMIZED_HLO: { + const XlaCompiler::CompilationResult* compilation_result = nullptr; + xla::LocalExecutable* executable = nullptr; + TF_RETURN_IF_ERROR( + cache->Compile(options, function, *args, compile_options, + XlaCompilationCache::CompileMode::kStrict, + &compilation_result, &executable)); + return executable->executable()->module().ToString(); + } + } +}
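A minimal usage sketch for the helper above, assuming a ProcessFunctionLibraryRuntime* `pflr`, a Device* `device`, and a std::vector<const Tensor*> `flat_inputs` already prepared by the caller; those names are illustrative, not part of this patch.

  // Unoptimized HLO for the named function; IrExportStage::OPTIMIZED_HLO
  // would instead go through XlaCompilationCache and return the module text
  // after XLA's optimization pipeline.
  xla::StatusOr<std::string> hlo_text = GetCompilerIr(
      IrExportStage::HLO, pflr, "my_function", device, flat_inputs);
  if (hlo_text.ok()) {
    VLOG(1) << *hlo_text;
  }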
+ +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/get_compiler_ir.h b/tensorflow/compiler/jit/get_compiler_ir.h new file mode 100644 index 00000000000..81e5af29279 --- /dev/null +++ b/tensorflow/compiler/jit/get_compiler_ir.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_ +#define TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_ + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace tensorflow { + +class ProcessFunctionLibraryRuntime; +class Device; +class Tensor; + +enum class IrExportStage { HLO, OPTIMIZED_HLO }; + +// Returns HLO text for a given function `func_name` using library runtime +// `runtime` on a device `dev` with given `inputs`. +xla::StatusOr GetCompilerIr( + IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, + absl::string_view func_name, Device* dev, + absl::Span inputs); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_ diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 38e33a60657..12b40b1c83b 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -63,38 +64,6 @@ namespace tensorflow { namespace { -XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) { - DeviceType device_type = ctx->device_type(); - se::Platform::Id platform_id = nullptr; - const XlaDevice::Metadata* xla_device_metadata = nullptr; - se::DeviceMemoryAllocator* custom_allocator = nullptr; - - if (ctx->device_type() == DeviceType(DEVICE_CPU)) { - platform_id = se::host::kHostPlatformId; - } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) { - platform_id = ctx->device() - ->tensorflow_gpu_device_info() - ->stream->parent() - ->platform() - ->id(); - } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) { - // If we are on an XlaDevice, use the underlying XLA platform's allocator - // directly. We could use the StreamExecutor's allocator which may - // theoretically be more correct, but XLA returns a nice OOM message in a - // Status and StreamExecutor does not. - // - // Importantly we can't use ctx->device()->GetAllocator() as the allocator - // (which xla_allocator above uses) as on an XlaDevice, this is a dummy - // allocator that returns XlaTensor objects. The XlaCompiler needs a real - // allocator to allocate real buffers. 
- platform_id = xla_device_metadata->platform()->id(); - custom_allocator = - xla_device_metadata->client()->backend().memory_allocator(); - } - - return XlaPlatformInfo(device_type, platform_id, xla_device_metadata, - custom_allocator); -} // A closure describing how to run a compiled version of a TensorFlow function. // @@ -178,31 +147,6 @@ class XlaExecutableClosureStore { TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore); }; -// Return allocator from platform info if non-null, or populate and return a -// pointer to the allocator adapter with allocator from context. -// -// This is necessary because for XLA devices the underlying TF allocator returns -// dummy tensors. -se::DeviceMemoryAllocator* GetAllocator( - absl::optional* tf_allocator_adapter, - OpKernelContext* ctx, const XlaPlatformInfo& platform_info) { - if (platform_info.custom_allocator()) { - return platform_info.custom_allocator(); - } - if (!ctx->op_device_context()) { - // Stream is not set for the host platform. - se::Platform* platform = - se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()) - .ValueOrDie(); - tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform); - return &tf_allocator_adapter->value(); - } - // platform_info. - tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), - ctx->op_device_context()->stream()); - return &tf_allocator_adapter->value(); -} - } // namespace XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, @@ -214,68 +158,13 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, constants_(constants), resources_(resources), function_(function), - platform_info_(PlatformInfoFromContext(ctx)), + platform_info_(XlaPlatformInfoFromDevice(ctx->device())), has_ref_vars_(has_ref_vars) {} -static Status BuildCompilationCache(OpKernelContext* ctx, - const XlaPlatformInfo& platform_info, - XlaCompilationCache** cache) { - if (platform_info.xla_device_metadata()) { - *cache = new XlaCompilationCache( - platform_info.xla_device_metadata()->client(), - platform_info.xla_device_metadata()->jit_device_type()); - return Status::OK(); - } - - auto platform = - se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()); - if (!platform.ok()) { - return platform.status(); - } - - xla::StatusOr compiler_for_platform = - xla::Compiler::GetForPlatform(platform.ValueOrDie()); - if (!compiler_for_platform.ok()) { - // In some rare cases (usually in unit tests with very small clusters) we - // may end up transforming an XLA cluster with at least one GPU operation - // (which would normally force the cluster to be compiled using XLA:GPU) - // into an XLA cluster with no GPU operations (i.e. containing only CPU - // operations). Such a cluster can fail compilation (in way that - // MarkForCompilation could not have detected) if the CPU JIT is not linked - // in. - // - // So bail out of _XlaCompile in this case, and let the executor handle the - // situation for us. 
- const Status& status = compiler_for_platform.status(); - if (status.code() == error::NOT_FOUND) { - return errors::Unimplemented("Could not find compiler for platform ", - platform.ValueOrDie()->Name(), ": ", - status.ToString()); - } - } - - xla::LocalClientOptions client_options; - client_options.set_platform(platform.ValueOrDie()); - client_options.set_intra_op_parallelism_threads( - ctx->device()->tensorflow_cpu_worker_threads()->num_threads); - auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); - if (!client.ok()) { - return client.status(); - } - const XlaOpRegistry::DeviceRegistration* registration; - if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(), - ®istration)) { - return errors::InvalidArgument("No JIT device registered for ", - platform_info.device_type().type()); - } - *cache = new XlaCompilationCache( - client.ValueOrDie(), DeviceType(registration->compilation_device_name)); - return Status::OK(); -} - static Status CompileToLocalExecutable( OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars, const XlaPlatformInfo& platform_info, + absl::Span inputs, absl::Span variable_infos, absl::Span constants, bool lazy, bool may_alias_resource_update, xla::LocalClient** client, @@ -292,7 +181,7 @@ static Status CompileToLocalExecutable( TF_RETURN_IF_ERROR(rm->LookupOrCreate( rm->default_container(), "xla_cache", &cache, [&](XlaCompilationCache** cache) { - return BuildCompilationCache(ctx, platform_info, cache); + return BuildXlaCompilationCache(ctx->device(), platform_info, cache); })); // Hold the reference to the JIT during evaluation. (We could probably // free it sooner because the ResourceMgr will retain a reference, but @@ -302,32 +191,11 @@ static Status CompileToLocalExecutable( *client = static_cast(cache->client()); absl::optional tf_allocator_adapter; - XlaCompiler::Options options; - options.client = *client; - if (ctx->op_device_context() != nullptr) { - options.device_ordinal = - ctx->op_device_context()->stream()->parent()->device_ordinal(); - } - options.device_type = cache->device_type(); - options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); - options.graph_def_version = ctx->function_library()->graph_def_version(); - options.allow_cpu_custom_calls = - (platform_info.platform_id() == se::host::kHostPlatformId); - options.device_allocator = - GetAllocator(&tf_allocator_adapter, ctx, platform_info); - if (platform_info.xla_device_metadata()) { - options.shape_representation_fn = - platform_info.xla_device_metadata()->shape_representation_fn(); - } - // If reference variables are not present in the graph, we can safely alias - // passthrough parameters without performing a copy. - options.alias_passthrough_params = - !has_ref_vars && !platform_info.is_on_xla_device(); + XlaCompiler::Options options = GenerateCompilerOptions( + *cache, *ctx->function_library(), ctx->device(), + ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr, + platform_info, has_ref_vars, &tf_allocator_adapter); - std::map constant_args; - for (int i : constants) { - constant_args.insert({i, ctx->input(i)}); - } XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; // Optimization: where possible, have the computation return a naked array @@ -337,10 +205,11 @@ static Status CompileToLocalExecutable( !platform_info.is_on_xla_device() && may_alias_resource_update; - std::vector args; - TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( - constant_args, variable_infos, ctx, &args)); - return cache->Compile(options, function, args, compile_options, + xla::StatusOr> args = + XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs, + variable_infos); + TF_RETURN_IF_ERROR(args.status()); + return cache->Compile(options, function, *args, compile_options, lazy ? XlaCompilationCache::CompileMode::kLazy : XlaCompilationCache::CompileMode::kStrict, compilation_result, executable); @@ -350,6 +219,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { VLOG(1) << "XlaLocalLaunchOpBase::Compute " << Canonicalize(function_.name(), AttrSlice(&function_.attr())); + std::vector inputs = InputsFromContext(ctx); xla::LocalClient* client; const XlaCompiler::CompilationResult* compilation_result; xla::LocalExecutable* executable; @@ -357,10 +227,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { std::vector variable_infos; { OP_REQUIRES_OK( - ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos)); + ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(), + inputs, resources_, &variable_infos)); OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos))); Status s = CompileToLocalExecutable( - ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, + ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, inputs, variable_infos, constants_, /*lazy=*/false, /*may_alias_resource_update=*/true, &client, &compilation_result, &executable); @@ -378,8 +249,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { VLOG(1) << "Executing XLA Computation..."; absl::optional tf_allocator_adapter; - se::DeviceMemoryAllocator* allocator = - GetAllocator(&tf_allocator_adapter, ctx, platform_info_); + se::DeviceMemoryAllocator* allocator = GetAllocator( + &tf_allocator_adapter, ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info_); int device_ordinal = stream ? 
stream->parent()->device_ordinal() : client->default_device_ordinal(); XlaComputationLaunchContext launch_context( @@ -503,7 +376,7 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) constants_(ConstantsVector(ctx)), resources_(ResourcesVector(ctx)), function_(FunctionAttr(ctx)), - platform_info_(PlatformInfoFromContext(ctx)), + platform_info_(XlaPlatformInfoFromDevice(ctx->device())), must_compile_(MustCompileAttr(ctx)), has_ref_vars_(HasRefVars(ctx)) {} @@ -515,6 +388,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { xla::LocalExecutable* executable; ResourceVarsSnapshot variables; + std::vector inputs = InputsFromContext(ctx); bool cannot_compile_cluster; { mutex_lock guard(cannot_compile_cluster_mu_); @@ -527,13 +401,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { } else { std::vector variable_infos; OP_REQUIRES_OK( - ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos)); + ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(), + inputs, resources_, &variable_infos)); OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos))); // Do not alias resource updates as locking variables in XlaCompile and // unlocking them in XlaRun may lead to deadlocks. Status status = CompileToLocalExecutable( - ctx, function_, has_ref_vars_, platform_info_, variable_infos, + ctx, function_, has_ref_vars_, platform_info_, inputs, variable_infos, constants_, /*lazy=*/!must_compile_, /*may_alias_resource_update=*/false, &client, &kernel, &executable); @@ -591,7 +466,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { } XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) - : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {} + : OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {} void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaRunOp " << def().name(); @@ -602,8 +477,10 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { XlaExecutableClosureStore::Global()->Consume(key); absl::optional tf_allocator_adapter; - se::DeviceMemoryAllocator* allocator = - GetAllocator(&tf_allocator_adapter, ctx, platform_info_); + se::DeviceMemoryAllocator* allocator = GetAllocator( + &tf_allocator_adapter, ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info_); se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; int device_ordinal = stream ? stream->parent()->device_ordinal() diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index 112408226a8..78707c8126d 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compilation_cache.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -31,61 +32,6 @@ limitations under the License. namespace tensorflow { -// Holds some information about the platform on which an -// XlaLaunch/_XlaCompile/_XlaRun op must run on. 
-class XlaPlatformInfo { - public: - XlaPlatformInfo() : device_type_("") {} - XlaPlatformInfo(XlaPlatformInfo&&) = default; - explicit XlaPlatformInfo(const DeviceType device_type, - se::Platform::Id platform_id, - const XlaDevice::Metadata* xla_device_metadata, - se::DeviceMemoryAllocator* device_allocator) - : device_type_(device_type), - platform_id_(platform_id), - xla_device_metadata_(xla_device_metadata), - device_allocator_(device_allocator) {} - - XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default; - - bool UseMultipleStreams() const { - return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams(); - } - - // Non-null only when run on an XLA device. - se::DeviceMemoryAllocator* custom_allocator() const { - return device_allocator_; - } - - DeviceType device_type() const { return device_type_; } - - // This is equal to xla_device_metadata()->platform()->id() if - // xla_device_metadata() is not nullptr. - se::Platform::Id platform_id() const { return platform_id_; } - - // This may be null if the op this XlaPlatformInfo is for was not placed on an - // XLA device. - const XlaDevice::Metadata* xla_device_metadata() const { - return xla_device_metadata_; - } - bool is_on_xla_device() const { return xla_device_metadata() != nullptr; } - - private: - DeviceType device_type_; - se::Platform::Id platform_id_; - - // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the - // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the - // XlaLaunch/_XlaCompile/_XlaRun OpKernel. - const XlaDevice::Metadata* xla_device_metadata_; - - // If the op associated with this XlaPlatformInfo is placed on an XLA device - // then device_allocator_ is the xla::Backend's memory allocator. If the op - // is placed on a regular CPU or GPU device then device_allocator_ is null. - se::DeviceMemoryAllocator* device_allocator_; - - TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo); -}; // XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. // The only difference is that it does not require arguments to follow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 19eb61b6f72..81403fbf2dc 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -32,12 +32,12 @@ limitations under the License. 
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" @@ -1196,12 +1196,9 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { continue; } - DeviceType jit_device_type(registration->compilation_device_name); - - RecursiveCompilabilityChecker::OperationFilter op_filter = - CreateOperationFilter(*registration); - - if (!RecursiveCompilabilityChecker{&op_filter, &jit_device_type} + if (!RecursiveCompilabilityChecker{ + CreateOperationFilter(*registration), + DeviceType{registration->compilation_device_name}} .IsCompilableNode(*node, lib_runtime)) { continue; } @@ -1718,7 +1715,6 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, const XlaOpRegistry::DeviceRegistration* registration; CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), ®istration)); - DeviceType jit_device_type(registration->compilation_device_name); // We can always *compile* resource operations, stateful RNGs and dummy ops, // even if we are sometimes unable to auto-cluster them. @@ -1733,7 +1729,8 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, op_filter.allow_slow_ops = true; op_filter.allow_inaccurate_ops = true; - RecursiveCompilabilityChecker checker{&op_filter, &jit_device_type}; + RecursiveCompilabilityChecker checker{ + op_filter, DeviceType{registration->compilation_device_name}}; if (!uncompilable_node_info) { // We do not need uncompilable node info. Just return the result. 
return checker.IsCompilableCall(ndef, flr); @@ -1837,7 +1834,9 @@ absl::flat_hash_map>* GetAllowlistTable() { "ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV", "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign", - "Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex"}}}; + "Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex", + "TensorStridedSliceUpdate", + }}}; // clang-format on return result; } @@ -1996,6 +1995,8 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "ResourceScatterNdUpdate", "ResourceScatterSub", "ResourceScatterUpdate", + "RngReadAndSkip", + "RngSkip", "Roll", "ScatterNd", "SelfAdjointEigV2", @@ -2018,11 +2019,17 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "StatelessCase", "StatelessIf", "StatelessMultinomial", + "StatelessRandomGetKeyCounterAlg", "StatelessRandomNormal", + "StatelessRandomNormalV2", "StatelessRandomUniform", + "StatelessRandomUniformV2", "StatelessRandomUniformInt", + "StatelessRandomUniformIntV2", "StatelessRandomUniformFullInt", + "StatelessRandomUniformFullIntV2", "StatelessTruncatedNormal", + "StatelessTruncatedNormalV2", "StatelessWhile", "Svd", "SymbolicGradient", @@ -2080,6 +2087,7 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "XlaSelectAndScatter", "XlaSelfAdjointEig", "XlaSend", + "XlaSetBound", "XlaSharding", "XlaSort", "XlaSpmdFullToShardShape", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index e88319bb732..1be3e5ba9e7 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -44,6 +44,11 @@ using ::tensorflow::testing::FindNodeByName; namespace tensorflow { namespace { +static bool Initialized = [] { + tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true; + return true; +}(); + REGISTER_OP("UncompilableNullary").Output("o: float"); REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 7378d17f88d..87c9fbf0af7 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -406,37 +406,6 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) { EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); } -TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output dynamic_slice_operand = - ops::Placeholder(s.WithOpName("dynamic_slice_operand"), DT_INT32, - ops::Placeholder::Attrs{}); - Output dynamic_slice_begin = ops::Placeholder( - s.WithOpName("dynamic_slice_begin"), DT_INT32, ops::Placeholder::Attrs{}); - Output dynamic_slice_size = ops::Placeholder( - s.WithOpName("dynamic_slice_size"), DT_INT32, ops::Placeholder::Attrs{}); - Output dynamic_slice = - ops::XlaDynamicSlice(s.WithOpName("dynamic_slice"), dynamic_slice_operand, - dynamic_slice_begin, dynamic_slice_size); - - Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), - DT_FLOAT, ops::Placeholder::Attrs{}); - Output reshape = - ops::Reshape(s.WithOpName("reshape"), reshape_input, dynamic_slice); - - AddToCluster({dynamic_slice.node(), reshape.node()}, "cluster_0"); - - std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); - 
TF_ASSERT_OK(s.ToGraph(graph.get())); - - Node* n = FindNodeByName(*graph, "dynamic_slice"); - ASSERT_NE(n, nullptr); - - TF_ASSERT_OK(PartiallyDecluster(&graph)); - - EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); -} - TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) { const char* const kClusteredProducer0Name = "ClusteredProducer0"; const char* const kClusteredProducer1Name = "ClusteredProducer1"; diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index b1525337dbc..fb184d62e27 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -20,9 +20,11 @@ limitations under the License. #include "absl/base/call_once.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" +#include "tensorflow/compiler/mlir/utils/array_container_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -278,25 +280,23 @@ Status XlaCompilationCache::CompileSingleOp( const NodeDef& node_def = ctx->op_kernel().def(); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); - bool are_args_supported = - absl::c_all_of(args, [](const XlaCompiler::Argument arg) { - return arg.kind == XlaCompiler::Argument::kConstant || - arg.kind == XlaCompiler::Argument::kParameter; + bool has_tensor_list_arg = + absl::c_any_of(args, [](const XlaCompiler::Argument arg) { + return arg.kind == XlaCompiler::Argument::kTensorList; }); const ConfigProto* config = ctx->function_library()->config_proto(); bool use_mlir = config && config->experimental().enable_mlir_bridge(); - // TODO(b/155596779): Understand the source of other argument types and - // depending on the source either support those or avoid these codepath. - if (!use_mlir || !are_args_supported) { + // TODO(b/155596779): Support TensorList args. 
+ if (!use_mlir || !has_tensor_list_arg) { return compiler->CompileGraph(compile_options, node_def.name(), std::move(graph), args, result); } GraphDebugInfo debug_info; return CompileGraphToXlaHlo( - *graph, {args.data(), args.size()}, options.device_type.type_string(), - compile_options.use_tuple_arg, *options.flib_def, debug_info, - options.shape_representation_fn, result); + *graph, mlir::SpanToArrayRef(args), + options.device_type.type_string(), compile_options.use_tuple_arg, + *options.flib_def, debug_info, options.shape_representation_fn, result); }; return CompileImpl(options, name, args, compile_op, /*compile_threshold=*/absl::nullopt, @@ -325,6 +325,10 @@ Status XlaCompilationCache::CompileImpl( absl::optional compile_threshold, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { + if (FailOnXlaCompilation()) { + return errors::Internal("XLA compilation disabled"); + } + DCHECK_NE(out_executable, nullptr); VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); diff --git a/tensorflow/compiler/jit/xla_compilation_cache_test.cc b/tensorflow/compiler/jit/xla_compilation_cache_test.cc index 7227615d2bb..5578925b790 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache_test.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache_test.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -52,6 +54,30 @@ TEST(XlaCompilationCacheTest, SignatureEquality) { } } +TEST(XlaCompilationCacheTest, TestDisabledXlaCompilation) { + NameAttrList fn; + fn.set_name("afunction"); + + DisableXlaCompilation(); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + DeviceType device_type = DeviceType(DEVICE_CPU_XLA_JIT); + + const XlaCompiler::CompilationResult* compilation_result; + xla::LocalExecutable* executable; + + auto cache = new XlaCompilationCache(client, device_type); + core::ScopedUnref cache_ref(cache); + + Status status = cache->Compile(XlaCompiler::Options{}, fn, {}, + XlaCompiler::CompileOptions{}, + XlaCompilationCache::CompileMode::kStrict, + &compilation_result, &executable); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE( + absl::StrContains(status.error_message(), "XLA compilation disabled")); +} + static void BM_BuildSignature(int iters, int n_args) { NameAttrList fn; fn.set_name("afunction"); diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 50813859603..d092508eccf 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -41,18 +42,23 @@ static std::vector GetResourceVariableIndices(OpKernelContext* ctx) { } Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, - const XlaDevice::Metadata& metadata, + XlaCompilationCache* cache, const XlaCompiler::CompilationResult* result, xla::LocalExecutable* executable, const ResourceVarsSnapshot& variable_args) { - xla::LocalClient* client = metadata.client(); + xla::LocalClient* client = static_cast(cache->client()); - // Builds an XLA allocator for the device. + absl::optional tf_allocator_adapter; + se::DeviceMemoryAllocator* allocator = GetAllocator( + &tf_allocator_adapter, ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info_); XlaComputationLaunchContext launch_context( - client, client->backend().memory_allocator(), - client->default_device_ordinal(), - /*allocate_xla_tensors=*/true, - /*use_multiple_streams=*/metadata.UseMultipleStreams()); + client, allocator, client->default_device_ordinal(), + /*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr, + platform_info_.xla_device_metadata() + ? platform_info_.xla_device_metadata()->UseMultipleStreams() + : false); std::map snapshot_ptrs; for (auto& p : variable_args) { @@ -70,12 +76,11 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; - TF_RET_CHECK(stream); VLOG(2) << "Executing computation: " << name(); xla::ExecutableRunOptions run_options; run_options.set_stream(stream); - run_options.set_allocator(client->backend().memory_allocator()); + run_options.set_allocator(allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(GetXLARandomSeed()); @@ -94,98 +99,39 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, return Status::OK(); } -Status XlaCompileOnDemandOp::MustArgumentBeConstant( - const OpKernel* op_kernel, int64 argument_idx, - FunctionLibraryRuntime* flib_runtime, bool* result) { - *result = false; +Status XlaCompileOnDemandOp::Compile( + OpKernelContext* ctx, const XlaCompiler::CompilationResult** result, + XlaCompilationCache** cache, ResourceVarsSnapshot* variable_args, + xla::LocalExecutable** executable) { - // TODO(jmolloy): This could be expensive, so memoize. std::vector constant_input_indices; TF_RETURN_IF_ERROR(GetCompileTimeConstInputs( - op_kernel, &constant_input_indices, flib_runtime)); - *result = absl::c_binary_search(constant_input_indices, argument_idx); - return Status::OK(); -} - -// TODO(ycao): Remove the need to call ShouldArgumentBeConstant. Its benefit is -// not clear yet and it causes heavy constant analysis to run twice. 
-Status XlaCompileOnDemandOp::ShouldArgumentBeConstant( - const OpKernel* op_kernel, int64 argument_idx, - FunctionLibraryRuntime* flib_runtime, bool* result) { - return MustArgumentBeConstant(op_kernel, argument_idx, flib_runtime, result); -} - -Status XlaCompileOnDemandOp::Compile( - OpKernelContext* ctx, const XlaDevice::Metadata& metadata, - const XlaCompiler::CompilationResult** result, - ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) { - std::map constant_arguments; - for (int64 i = 0; i < ctx->num_inputs(); ++i) { - const Tensor& device_tensor = ctx->input(i); - if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) { - if (xla_tensor->has_host_tensor()) { - bool should_arg_be_const; - TF_RETURN_IF_ERROR(ShouldArgumentBeConstant(&ctx->op_kernel(), i, - ctx->function_library(), - &should_arg_be_const)); - if (should_arg_be_const) { - constant_arguments[i] = xla_tensor->host_tensor(); - } - } - } - - if (constant_arguments.count(i) == 0) { - bool must_argument_be_const; - TF_RETURN_IF_ERROR(MustArgumentBeConstant(&ctx->op_kernel(), i, - ctx->function_library(), - &must_argument_be_const)); - - if (must_argument_be_const) { - // Slow path; the argument is not available as a host constant so we - // must fetch it synchronously. - Tensor host_tensor; - AllocatorAttributes attrs; - attrs.set_on_host(true); - TF_RETURN_IF_ERROR(ctx->allocate_temp( - device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs)); - Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync( - &device_tensor, "ConstantArgument", - reinterpret_cast(ctx->device()), &host_tensor); - if (!status.ok()) { - LOG(ERROR) << "Copying tensor of shape " - << device_tensor.shape().DebugString() << " from " - << ctx->device()->name() << "to CPU failed with " - << status.ToString(); - return status; - } - constant_arguments[i] = host_tensor; - } - } + &ctx->op_kernel(), &constant_input_indices, ctx->function_library())); + if (!absl::c_all_of(constant_input_indices, [&](int idx) { + return ctx->input_memory_type(idx) == HOST_MEMORY; + })) { + return errors::Internal("Unexpected device placement for a constant input"); } + std::vector inputs = InputsFromContext(ctx); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. ResourceMgr* rm = ctx->resource_manager(); CHECK(rm); - XlaCompilationCache* cache; TF_RETURN_IF_ERROR(rm->LookupOrCreate( - rm->default_container(), "xla_cache", &cache, - [&](XlaCompilationCache** cache) { - *cache = new XlaCompilationCache(metadata.client(), - metadata.jit_device_type()); - return Status::OK(); + rm->default_container(), "xla_cache", cache, + [&](XlaCompilationCache** write_into_cache) { + return BuildXlaCompilationCache(ctx->device(), platform_info_, + write_into_cache); })); - // Hold the reference to the JIT during evaluation. (We could probably - // free it sooner because the ResourceMgr will retain a reference, but - // this is more obviously correct.) - core::ScopedUnref cache_ref(cache); - XlaCompiler::Options options; - options.device_type = metadata.jit_device_type(); - options.client = metadata.client(); - options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); - options.shape_representation_fn = metadata.shape_representation_fn(); + absl::optional tf_allocator_adapter; + XlaCompiler::Options options = GenerateCompilerOptions( + **cache, *ctx->function_library(), ctx->device(), + ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr, + platform_info_, + /*has_ref_vars=*/true, &tf_allocator_adapter); XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; @@ -194,31 +140,41 @@ Status XlaCompileOnDemandOp::Compile( compile_options.always_return_tuple = false; std::vector variables_indices = GetResourceVariableIndices(ctx); - std::vector args; + xla::StatusOr> args; { std::vector variable_infos; TF_RETURN_IF_ERROR( - GetVariableInfosFromCtxInputs(ctx, variables_indices, &variable_infos)); + GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(), + inputs, variables_indices, &variable_infos)); + TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos))); TF_RETURN_IF_ERROR(SnapshotResourceVariables( ctx, variables_indices, variable_infos, variable_args)); - TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( - constant_arguments, variable_infos, ctx, &args)); + + args = XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_input_indices, inputs, variable_infos); + TF_RETURN_IF_ERROR(args.status()); } - return cache->CompileSingleOp(options, args, ctx, compile_options, result, - executable); + return (*cache)->CompileSingleOp(options, *args, ctx, compile_options, result, + executable); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { const XlaCompiler::CompilationResult* result; xla::LocalExecutable* executable; - const XlaDevice::Metadata* metadata; - OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata)); ResourceVarsSnapshot variable_args; + XlaCompilationCache* cache; + OP_REQUIRES(ctx, ctx->function_library(), + errors::Internal("Function library missing")); OP_REQUIRES_OK(ctx, - Compile(ctx, *metadata, &result, &variable_args, &executable)); - OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable, variable_args)); + Compile(ctx, &result, &cache, &variable_args, &executable)); + + // Hold the reference to the JIT during evaluation. (We could probably + // free it sooner because the ResourceMgr will retain a reference, but + // this is more obviously correct.) + core::ScopedUnref cache_ref(cache); + OP_REQUIRES_OK(ctx, Run(ctx, cache, result, executable, variable_args)); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h index cc5f2f1e42f..bb8ab889ce9 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/framework/function.h" @@ -35,25 +36,25 @@ namespace tensorflow { // vanilla TensorFlow op as long as the bridge supports it. 
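In the rewritten Compile() above, the compilation cache is created through BuildXlaCompilationCache() and owned by the ResourceMgr, while Compute() pins it with a ScopedUnref for the duration of the run. A condensed sketch of that lifetime pattern, using the same names as the diff (error handling trimmed; to be read as part of the Status-returning Compile()):

// One XlaCompilationCache per device lives in the ResourceMgr; the caller
// holds an extra reference until the compiled executable has finished running.
XlaCompilationCache* cache = nullptr;
TF_RETURN_IF_ERROR(ctx->resource_manager()->LookupOrCreate<XlaCompilationCache>(
    ctx->resource_manager()->default_container(), "xla_cache", &cache,
    [&](XlaCompilationCache** write_into_cache) {
      return BuildXlaCompilationCache(ctx->device(), platform_info_,
                                      write_into_cache);
    }));
core::ScopedUnref cache_ref(cache);  // dropped only after Run() completes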
class XlaCompileOnDemandOp : public OpKernel { public: - explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) + : OpKernel(ctx), + platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {} void Compute(OpKernelContext* ctx) override; private: XlaCompiler::Argument CreateCompilerArgument(OpKernelContext* ctx, int64 i); - Status ShouldArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx, - FunctionLibraryRuntime* flib_runtime, - bool* result); - Status MustArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx, - FunctionLibraryRuntime* flib_runtime, - bool* result); - Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + Status Compile(OpKernelContext* ctx, const XlaCompiler::CompilationResult** result, + XlaCompilationCache** cache, ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable); - Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + + Status Run(OpKernelContext* ctx, XlaCompilationCache* cache, const XlaCompiler::CompilationResult* result, xla::LocalExecutable* executable, const ResourceVarsSnapshot& variable_args); + + const XlaPlatformInfo platform_info_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 446cd8944de..dd1ddb616f5 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -51,7 +51,7 @@ Status XlaCpuDeviceFactory::CreateDevices( std::vector>* devices) { XlaDeviceFlags* flags = GetXlaDeviceFlags(); if (!flags->tf_xla_enable_xla_devices) { - LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; + VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; return Status::OK(); } bool compile_on_demand = flags->tf_xla_compile_on_demand; diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 7842513331d..089d22dca03 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -61,6 +61,21 @@ limitations under the License. namespace tensorflow { +// Default PaddedShapeFn implementation that simply returns the unpadded +// on-device shape. This is accurate for CPU and GPU devices that neither +// transpose nor pad tensors. +Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { + const tensorflow::XlaTensor* xla_tensor = + tensorflow::XlaTensor::FromTensor(&tensor); + if (xla_tensor == nullptr) { + return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape); + } + + const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); + *shape = shaped_buffer.on_device_shape(); + return Status::OK(); +} + // Caches a XlaDeviceAllocator per pair. A // XlaDeviceAllocator is created on demand and is associated with a // XlaDevice. It outlives the device itself (for instance, the buffer @@ -116,20 +131,6 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( namespace { -// Default PaddedShapeFn implementation that simply returns the unpadded -// on-device shape. This is accurate for CPU and GPU devices that neither -// transpose nor pad tensors. 
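DefaultPaddedShapeFn is moved out of the anonymous namespace here (and declared in xla_device.h later in this diff) so that other translation units can use it as a PaddedShapeFn. A small usage sketch, assuming an ordinary host tensor rather than an XlaTensor:

// For a plain Tensor (no backing XlaTensor) the function simply converts
// dtype + TensorShape into the equivalent unpadded xla::Shape.
Tensor t(DT_FLOAT, TensorShape({2, 3}));
xla::Shape shape;
Status s = DefaultPaddedShapeFn(t, &shape);
if (s.ok()) {
  VLOG(1) << "On-device shape: " << shape.ToString();
}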
-Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { - const tensorflow::XlaTensor* xla_tensor = - tensorflow::XlaTensor::FromTensor(&tensor); - if (xla_tensor == nullptr) { - return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape); - } - - const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); - *shape = shaped_buffer.on_device_shape(); - return Status::OK(); -} static DeviceAttributes BuildXlaDeviceAttributes(const string& name_prefix, const string& device_name, @@ -572,8 +573,7 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, // Any op assigned to the device that isn't rewritten by the graph rewriter // gets executed by an XlaCompileOnDemandOp, which compiles it and executes // it just-in-time. - OpKernel* (*factory)(OpKernelConstruction*) = - [](OpKernelConstruction* context) -> OpKernel* { + auto factory = [](OpKernelConstruction* context) -> OpKernel* { return new XlaCompileOnDemandOp(context); }; XlaOpRegistry::RegisterCompilationKernels(); @@ -582,6 +582,13 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, jit_device, /*include_compilation_only_kernels=*/false)) { KernelDef* def = new KernelDef(*jit_def); + const std::unordered_set* constant_inputs = + XlaOpRegistry::CompileTimeConstantInputArgNames(def->op()); + + for (const std::string& arg_name : *constant_inputs) { + def->add_host_memory_arg(arg_name); + } + def->set_device_type(device); registrations->op_kernel_registrars.emplace_back( new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp", diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 30f9a99e36a..6d6086ce0fa 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -94,6 +94,11 @@ class XlaDevice : public LocalDevice { static Status GetMetadata(OpKernelConstruction* ctx, const Metadata** metadata); + // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by + // `device`. + static Status GetMetadataFromDevice(DeviceBase* device, + const XlaDevice::Metadata** metadata); + struct Options { // The StreamExecutor platform. Not owned. Must be non-null. 
se::Platform* platform = nullptr; @@ -196,8 +201,6 @@ class XlaDevice : public LocalDevice { xla::StatusOr> GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - static Status GetMetadataFromDevice(DeviceBase* device, - const XlaDevice::Metadata** metadata); Status MakeTensorFromProto(XlaDeviceContext* device_context, const TensorProto& tensor_proto, @@ -280,6 +283,8 @@ struct XlaDeviceOpRegistrations { XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device); +Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 16f496d51a3..99ba5658819 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -66,7 +66,7 @@ class XlaGpuDeviceFactory : public DeviceFactory { Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector* devices) { XlaDeviceFlags* flags = GetXlaDeviceFlags(); if (!flags->tf_xla_enable_xla_devices) { - LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; + VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc deleted file mode 100644 index f720183e196..00000000000 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Registers the XLA_INTERPRETER device which exposes the XLA Interpreter. 
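The RegisterXlaDeviceKernels() change above pins every compile-time-constant input of a registered op to host memory, which is what lets the new XlaCompileOnDemandOp::Compile() simply verify the placement instead of copying values back from the device. The two ends of that contract, condensed from the hunks above (`def`, `ctx` and `constant_input_indices` are the variables from those hunks):

// Producer side (RegisterXlaDeviceKernels): constant inputs become
// host-memory arguments of the on-demand kernel.
for (const std::string& arg_name :
     *XlaOpRegistry::CompileTimeConstantInputArgNames(def->op())) {
  def->add_host_memory_arg(arg_name);
}

// Consumer side (XlaCompileOnDemandOp::Compile): values are already on the
// host, so the old synchronous device-to-host copy path is gone.
if (!absl::c_all_of(constant_input_indices, [&](int idx) {
      return ctx->input_memory_type(idx) == HOST_MEMORY;
    })) {
  return errors::Internal("Unexpected device placement for a constant input");
}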
- -#include "absl/memory/memory.h" -#include "tensorflow/compiler/jit/kernels/xla_ops.h" -#include "tensorflow/compiler/jit/xla_device.h" -#include "tensorflow/compiler/jit/xla_device_ops.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" - -namespace tensorflow { - -const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER"; -const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT"; - -constexpr std::array kExecAllTypes = { - {DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, - DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; - -class XlaInterpreterDeviceFactory : public DeviceFactory { - public: - Status ListPhysicalDevices(std::vector* devices) override; - Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector>* devices) override; -}; - -Status XlaInterpreterDeviceFactory::ListPhysicalDevices( - std::vector* devices) { - devices->push_back( - absl::StrCat("/physical_device:", DEVICE_XLA_INTERPRETER, ":0")); - - return Status::OK(); -} - -Status XlaInterpreterDeviceFactory::CreateDevices( - const SessionOptions& session_options, const string& name_prefix, - std::vector>* devices) { - static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels( - DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT); - (void)registrations; - - XlaOpRegistry::DeviceRegistration registration; - registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; - registration.autoclustering_policy = - XlaOpRegistry::AutoclusteringPolicy::kAlways; - registration.cluster_resource_variable_ops_unsafely = true; - registration.cluster_stack_ops = false; - registration.cluster_tensor_array_ops = true; - registration.cluster_stateful_rng_ops = true; - registration.cluster_control_trigger = true; - registration.elide_assert_and_checknumerics = true; - registration.cluster_variant_ops = true; - registration.cluster_slow_ops = true; - registration.cluster_inaccurate_ops = true; - XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER, - registration); - - TF_ASSIGN_OR_RETURN( - auto platform, se::MultiPlatformManager::PlatformWithName("Interpreter")); - - XlaDevice::Options options; - options.platform = platform; - options.device_name_prefix = name_prefix; - options.device_name = DEVICE_XLA_INTERPRETER; - options.device_ordinal = 0; - options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; - options.use_multiple_streams = false; - devices->push_back(absl::make_unique(session_options, options)); - - return Status::OK(); -} - -// Set priority to be below the default priority (50), so that Interpreter is -// not selected as a high priority device over other default devices. See -// constructor comments for Registrar in -// tensorflow/core/common_runtime/device_factory.h for a list of priority for -// devices. 
-REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_INTERPRETER, - XlaInterpreterDeviceFactory, 40); - -// Kernel registrations -static bool OpFilter(KernelDef* kdef) { return true; } - -REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp, - kExecAllTypes); -REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp, - kExecAllTypes); -REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes); - -REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes); -REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter); - -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index 5ca146969e0..7387978fbcd 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -14,10 +14,21 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/xla_kernel_creator.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/jit/compilability_check_util.h" +#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/xla_kernel_creator_util.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" +#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -27,6 +38,78 @@ bool XlaKernelCreator::CanCreateKernel( return CanCreateXlaKernel(props->node_def); } +static Status CreateXlaKernel(FunctionLibraryRuntime* flr, + const NodeDef& node_def, + std::unique_ptr* kernel) { + if (!CanCreateXlaKernel(node_def)) { + return errors::Internal("Invalid node: ", node_def.ShortDebugString()); + } + + VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString(); + + // Make sure that kernels have been registered on the JIT device. + XlaOpRegistry::RegisterCompilationKernels(); + + // Only check for compilability if the MLIR bridge is not enabled. 
+ if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { + RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; + if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { + std::vector + uncompilable_node_info; + for (const auto& it : uncompilable_nodes_map) { + for (const auto& info : it.second.second) { + uncompilable_node_info.emplace_back(info); + } + } + string message = absl::StrCat( + "Function invoked by the following node is not compilable: ", + SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n"); + absl::StrAppend(&message, "Uncompilable nodes:"); + for (const auto& node_info : uncompilable_node_info) { + string node_message = absl::StrCat("\n", node_info.name, ": ", + node_info.uncompilable_reason, "\n", + "\tStacktrace:\n"); + for (const auto& stack_frame : node_info.stack_trace) { + absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n", + stack_frame.name, stack_frame.function_name); + } + absl::StrAppend(&message, node_message); + } + VLOG(1) << message; + return errors::InvalidArgument(message); + } + } + + // Get function body, constant args, and resource args. + NameAttrList function; + TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); + const FunctionBody* fbody = nullptr; + std::vector constant_arg_indices; + std::vector resource_arg_indices; + TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( + flr, function, &fbody, &constant_arg_indices, &resource_arg_indices)); + + MemoryTypeVector input_memory_types = + GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices); + MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody); + + // Create the kernel. + Device* dev = flr->device(); + Status s; + auto props = std::make_shared( + &fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types); + OpKernelConstruction construction(DeviceType(dev->device_type()), dev, + dev->GetAllocator(AllocatorAttributes()), + flr, dev->resource_manager(), props, + input_memory_types, output_memory_types, + flr->graph_def_version(), &s); + + *kernel = absl::make_unique( + &construction, constant_arg_indices, resource_arg_indices, function, + /*has_ref_vars=*/false); + return s; +} + Status XlaKernelCreator::CreateKernel( FunctionLibraryRuntime* flr, const std::shared_ptr& props, @@ -34,19 +117,12 @@ Status XlaKernelCreator::CreateKernel( return CreateXlaKernel(flr, props->node_def, kernel); } -namespace { - -bool RegisterLaunchOpCreator() { +static bool RegisterLaunchOpCreator() { XlaKernelCreator* xla_kernel_creator = new XlaKernelCreator(); RegisterDefaultCustomKernelCreator(xla_kernel_creator); return true; } static bool register_me = RegisterLaunchOpCreator(); -static bool register_xla = [] { - SetXlaIsEnabled(); - return true; -}(); -} // end namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc deleted file mode 100644 index 61c89d8a67a..00000000000 --- a/tensorflow/compiler/jit/xla_kernel_creator_util.cc +++ /dev/null @@ -1,186 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
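With xla_kernel_creator_util.cc deleted below, CreateXlaKernel() now lives directly in xla_kernel_creator.cc and is reached through the XlaKernelCreator instance registered by RegisterLaunchOpCreator(). A hedged sketch of how the creator is consulted when a compilable function node is instantiated; `flr` and `props` stand for the FunctionLibraryRuntime and NodeProperties of that node, and the surrounding plumbing is simplified:

// Sketch only: consulting the custom kernel creator for a function node.
XlaKernelCreator creator;
std::unique_ptr<OpKernel> kernel;
if (creator.CanCreateKernel(*flr, props)) {
  // Internally: check compilability (unless the MLIR bridge is enabled),
  // compute constant/resource argument indices and memory types, then build
  // an XlaLocalLaunchBase kernel for the function.
  TF_RETURN_IF_ERROR(creator.CreateKernel(flr, props, &kernel));
}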
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/jit/xla_kernel_creator_util.h" - -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "tensorflow/compiler/jit/compilability_check_util.h" -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/kernels/xla_ops.h" -#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" -#include "tensorflow/compiler/tf2xla/const_analysis.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/util/ptr_util.h" - -namespace tensorflow { -namespace { - -// Utility which searches for values in a sorted list by scanning over it once. -// No matter how many times ScanForValue is called, the list is scanned at most -// once. However, if a call to ScanForValue skips over a value, that value is -// not revisited in future calls to ScanForValue, so callers must take -// care to order their calls. -// -// Useful for merging multiple sorted lists in O(n) time. -class SinglePassSearch { - public: - // Creates a SinglePassSearch object that can be used to search in `values`. - // Does not take ownership of `values`. `values` must outlive this. - // `values` must be sorted. - explicit SinglePassSearch(const std::vector* values) - : current_index_(0), values_(values) {} - - // Scans forward in the vector looking for "value", updating the internal - // position in to the vector. - // Returns true iff the vector contains the given value at or after current - // position. - // Not thread-safe. - bool ScanForValue(int value) { - while (current_index_ < values_->size() && - (*values_)[current_index_] <= value) { - if ((*values_)[current_index_] == value) { - current_index_++; - return true; - } - current_index_++; - } - return false; - } - - private: - int current_index_; - const std::vector* values_; -}; -} // namespace - -Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, - std::unique_ptr* kernel) { - if (!CanCreateXlaKernel(node_def)) { - return errors::Internal("Invalid node: ", node_def.ShortDebugString()); - } - - VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString(); - - // Make sure that kernels have been registered on the JIT device. - XlaOpRegistry::RegisterCompilationKernels(); - - // Only check for compilability if the MLIR bridge is not enabled. 
- if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { - RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; - if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { - std::vector - uncompilable_node_info; - for (const auto& it : uncompilable_nodes_map) { - for (const auto& info : it.second.second) { - uncompilable_node_info.emplace_back(info); - } - } - string message = absl::StrCat( - "Function invoked by the following node is not compilable: ", - SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n"); - absl::StrAppend(&message, "Uncompilable nodes:"); - for (const auto& node_info : uncompilable_node_info) { - string node_message = absl::StrCat("\n", node_info.name, ": ", - node_info.uncompilable_reason, "\n", - "\tStacktrace:\n"); - for (const auto& stack_frame : node_info.stack_trace) { - absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n", - stack_frame.name, stack_frame.function_name); - } - absl::StrAppend(&message, node_message); - } - VLOG(1) << message; - return errors::InvalidArgument(message); - } - } - - // Get function body, constant args, and resource args. - const FunctionBody* fbody = nullptr; - std::vector constant_arg_indices; - std::vector resource_arg_indices; - TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( - flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices)); - - // Set input and output memory types. - MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY); - // These indices are used only for optimization purposes. They allow us - // to loop over constant_arg_indices and resource_arg_indices only once - // while iterating over all the function arguments checking if it is a - // resource or a constant. - // The reason we optimized this code is because functions can have a lot of - // captured arguments. For example, the backward pass of ResNet50 takes in all - // 214 variables and a similar number of activations. - SinglePassSearch constants_search(&constant_arg_indices); - SinglePassSearch resources_search(&resource_arg_indices); - for (size_t i = 0; i < fbody->arg_types.size(); ++i) { - if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) { - // Compile-time constants and resource handles are expected to be in - // host memory. - input_memory_types[i] = HOST_MEMORY; - } - } - // One might wonder, about the case where a compile-time constant argument - // (which must be in host memory) is also used as an input into an op, - // e.g. Add, that expects its inputs in device memory. Here is how it - // works now. - // First, what do we mean by "op expects an input in XYZ memory"? - // There are two types of "ops" here: the tf2xla kernel and the HLO - // computation it builds. The tf2xla kernel needs to retrieve the actual - // numeric value of the compile-time constant tensors, so it really expects - // them to be on in host memory. However, for other inputs, it refers to them - // using xla::ComputationDataHandle, which is just a symbolic handle that - // xla::ComputationBuilder assigns. How does this handle gets assigned for - // constant arguments? Even constant arguments get an _Arg node in the graph - // instantiated for Function compilation. The tf2xla kernel for constant _Arg - // nodes takes the constant value, converts it to XlaLiteral, and feeds it - // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. 
This - // constant XlaLiteral is included in the HLO graph, and subsequently, in - // the actual executable, which is copied to the device before being - // executed. Thus, when this executable runs, the constant is available in - // device memory. - - // XlaLaunch kernel keeps all outputs (including constants, which it copies), - // in device memory except for resources. - MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); - for (size_t i = 0; i < fbody->ret_types.size(); ++i) { - if (fbody->ret_types[i] == DT_RESOURCE) { - output_memory_types[i] = HOST_MEMORY; - } - } - - // Create the kernel. - NameAttrList function; - TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); - Device* dev = flr->device(); - Status s; - auto props = std::make_shared( - &fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types); - OpKernelConstruction construction(DeviceType(dev->device_type()), dev, - dev->GetAllocator(AllocatorAttributes()), - flr, dev->resource_manager(), props, - input_memory_types, output_memory_types, - flr->graph_def_version(), &s); - - *kernel = absl::make_unique( - &construction, constant_arg_indices, resource_arg_indices, function, - /*has_ref_vars=*/false); - return s; -} -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 19e2b5a2bb5..a8b090f1450 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -44,12 +44,6 @@ namespace { using xla::ScopedShapedBuffer; using xla::ShapedBuffer; -const char kPossibleNonVariableResourceHintMessage[] = - "If the error is similar to `Trying to access resource using the wrong " - "type`, this is likely because XLA only accepts Resource Variables as " - "inputs by snapshotting their values. Other TensorFlow resource types like " - "TensorList/TensorArray/Stack are not supported. Try removing non-variable " - "resource inputs to XLA."; } // anonymous namespace VariableInfo::VariableInfo(int index, absl::string_view name, Var* var) @@ -85,19 +79,22 @@ VariableInfo::~VariableInfo() { } } -// Returns a vector of VariableInfo instances for the resource variable inputs -// to the kernel with context `ctx`. The input indices for the resource -// variable inputs are in `variable_indices`. -Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx, - absl::Span variable_indices, - std::vector* result) { +Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, + absl::Span inputs, + absl::Span variable_indices, + std::vector* result) { result->clear(); result->reserve(variable_indices.size()); for (int var_idx : variable_indices) { Var* variable = nullptr; - ResourceHandle handle = HandleFromInput(ctx, var_idx); - TF_RETURN_IF_ERROR( - LookupOrCreateResource(ctx, handle, &variable, [&](Var** ptr) { + ResourceHandle handle = inputs[var_idx]->flat()(0); + if (handle.device() != dev->attributes().name()) { + return errors::InvalidArgument("Trying to access resource ", + handle.name(), " located in device ", + dev->name()); + } + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + handle.container(), handle.name(), &variable, [](Var** ptr) { // This var is uninitialized for now. 
*ptr = new Var(DT_INVALID); return Status::OK(); @@ -107,6 +104,15 @@ Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx, return Status::OK(); } +std::vector InputsFromContext(OpKernelContext* ctx) { + std::vector inputs; + inputs.reserve(ctx->num_inputs()); + for (int input_idx = 0; input_idx < ctx->num_inputs(); input_idx++) { + inputs.push_back(&ctx->input(input_idx)); + } + return inputs; +} + Status LockVariables(absl::Span variables) { std::vector lock_order(variables.size()); std::iota(lock_order.begin(), lock_order.end(), 0); @@ -358,9 +364,6 @@ static Status SetOutputForConstant( ctx->set_output(output_num, const_tensor); output_tensor = ctx->mutable_output(output_num); } - if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) { - xla_tensor->set_host_tensor(const_tensor); - } return Status::OK(); } @@ -557,11 +560,14 @@ Status XlaComputationLaunchContext::PopulateOutputs( return Status::OK(); } -Status XlaComputationLaunchContext::BuildXlaCompilerArguments( - const std::map& must_be_constant_args, - absl::Span variable_args, OpKernelContext* ctx, - std::vector* args) { - args->resize(ctx->num_inputs()); +xla::StatusOr> +XlaComputationLaunchContext::BuildXlaCompilerArguments( + absl::Span must_be_constant_idxs, + absl::Span inputs, + absl::Span variable_args) { + CHECK(absl::c_is_sorted(must_be_constant_idxs)); + std::vector out; + out.resize(inputs.size()); absl::flat_hash_map variable_info_lookup; for (const VariableInfo& info : variable_args) { @@ -571,33 +577,20 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments( variable_info_lookup.emplace(info.index(), &info); } - for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { - XlaCompiler::Argument& arg = (*args)[input_num]; + for (int64 input_num = 0; input_num < inputs.size(); ++input_num) { + const Tensor* input = inputs[input_num]; - if (must_be_constant_args.count(input_num) > 0) { + XlaCompiler::Argument& arg = out[input_num]; + if (absl::c_binary_search(must_be_constant_idxs, input_num)) { // Handles compile-time constants. - const Tensor& input = must_be_constant_args.at(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); + TF_RET_CHECK(input->dtype() != DT_RESOURCE); arg.kind = XlaCompiler::Argument::kConstant; - arg.type = input.dtype(); - arg.shape = input.shape(); - arg.constant_value = input; - } else if (variable_info_lookup.count(input_num) == 0) { - // Handles the non-constant arguments. - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); - if (input.NumElements() > 0) { - arg.kind = XlaCompiler::Argument::kParameter; - } else { - arg.kind = XlaCompiler::Argument::kConstant; - arg.constant_value = input; - } - arg.type = input.dtype(); - arg.shape = input.shape(); - } else { + arg.type = input->dtype(); + arg.shape = input->shape(); + arg.constant_value = *input; + } else if (variable_info_lookup.count(input_num)) { // Handles resource variables. - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() == DT_RESOURCE); + TF_RET_CHECK(input->dtype() == DT_RESOURCE); const VariableInfo& variable = *variable_info_lookup[input_num]; arg.name = std::string(variable.name()); arg.kind = XlaCompiler::Argument::kResource; @@ -616,10 +609,21 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments( arg.type = DT_INVALID; arg.shape = TensorShape(); } + } else { + // Normal inputs. 
+ TF_RET_CHECK(input->dtype() != DT_RESOURCE); + if (input->NumElements() > 0) { + arg.kind = XlaCompiler::Argument::kParameter; + } else { + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = *input; + } + arg.type = input->dtype(); + arg.shape = input->shape(); } } - return Status::OK(); + return out; } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index b34b3059a4f..ac085a022c8 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -109,12 +109,16 @@ Status SnapshotResourceVariables(OpKernelContext* ctx, Status LockVariables(absl::Span variables) TF_EXCLUSIVE_LOCK_FUNCTION(); -// Returns a vector of VariableInfo instances for the resource variable inputs -// to the kernel with context `ctx`. The input indices for the resource +// Returns a vector of VariableInfo instances for the resource variable inputs, +// given that *all* inputs are in `inputs`. The input indices for the resource // variable inputs are in `variable_indices`. -Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx, - absl::Span variable_indices, - std::vector* result); +Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, + absl::Span inputs, + absl::Span variable_indices, + std::vector* result); + +// Returns pointers to inputs stored in `ctx`. +std::vector InputsFromContext(OpKernelContext* ctx); // Helper class to perform the marshalling of TensorFlow inputs and outputs to // ShapedBuffers suitable for passing to an XLA computation. @@ -136,10 +140,10 @@ class XlaComputationLaunchContext { // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch // op. // Precondition: variables in `variable_args` are locked. - static Status BuildXlaCompilerArguments( - const std::map& constant_args, - absl::Span variable_args, OpKernelContext* ctx, - std::vector* args); + static xla::StatusOr> + BuildXlaCompilerArguments(absl::Span must_be_constant_idxs, + absl::Span inputs, + absl::Span variable_args); // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. diff --git a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc new file mode 100644 index 00000000000..6c6c490e032 --- /dev/null +++ b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc @@ -0,0 +1,98 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Register XlaXXX operations on regular CPU/GPU devices using +// `XlaCompileOnDemandOp`. 
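BuildXlaCompilerArguments() no longer reads from OpKernelContext directly: it takes a sorted span of constant-input indices, the raw input tensors, and already-locked VariableInfo objects, and returns the argument vector in a StatusOr. A condensed caller-side sketch mirroring the call sites in this diff (error handling trimmed; `variables_indices` and `constant_input_indices` are the vectors computed earlier in Compile()):

// Building XlaCompiler arguments from plain tensors and locked variables.
std::vector<const Tensor*> inputs = InputsFromContext(ctx);
std::vector<VariableInfo> variable_infos;
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(ctx->resource_manager(),
                                              ctx->device(), inputs,
                                              variables_indices,
                                              &variable_infos));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));

xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
    XlaComputationLaunchContext::BuildXlaCompilerArguments(
        constant_input_indices, inputs, variable_infos);
TF_RETURN_IF_ERROR(args.status());
// *args now holds kConstant / kResource / kParameter entries in input order.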
+#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +#define REGISTER_XLA_OPS_ON_DEVICE(DEVICE) \ + REGISTER_KERNEL_BUILDER(Name("XlaConv") \ + .HostMemory("window_strides") \ + .HostMemory("padding") \ + .HostMemory("lhs_dilation") \ + .HostMemory("rhs_dilation") \ + .HostMemory("feature_group_count") \ + .Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("XlaBroadcastHelper").HostMemory("broadcast_dims").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaSelfAdjointEig").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaSvd").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaDot").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("XlaDynamicSlice").HostMemory("size_indices").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaDynamicUpdateSlice").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaIf").Device(DEVICE), XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaPad") \ + .HostMemory("padding_low") \ + .HostMemory("padding_high") \ + .HostMemory("padding_interior") \ + .Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaRecv").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaReduce").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaReduceWindow") \ + .HostMemory("window_dimensions") \ + .HostMemory("window_strides") \ + .HostMemory("base_dilations") \ + .HostMemory("window_dilations") \ + .HostMemory("padding") \ + .Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaSelectAndScatter") \ + .HostMemory("window_dimensions") \ + .HostMemory("window_strides") \ + .HostMemory("padding") \ + .Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaSend").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaSort").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaKeyValueSort").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaWhile").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaDequantize").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaEinsum").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaSpmdShardToFullShape").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaSharding").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaReplicaId").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("XlaGather").HostMemory("slice_sizes").Device(DEVICE), \ + XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaScatter").Device(DEVICE), \ + XlaCompileOnDemandOp); + +REGISTER_XLA_OPS_ON_DEVICE(DEVICE_CPU); +REGISTER_XLA_OPS_ON_DEVICE(DEVICE_GPU); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc new file mode 100644 index 00000000000..b38bf9282b1 --- /dev/null +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -0,0 +1,158 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_platform_info.h" + +#include "tensorflow/compiler/xla/client/client_library.h" + +namespace tensorflow { + +Status BuildXlaCompilationCache(DeviceBase* device, + const XlaPlatformInfo& platform_info, + XlaCompilationCache** cache) { + if (platform_info.xla_device_metadata()) { + *cache = new XlaCompilationCache( + platform_info.xla_device_metadata()->client(), + platform_info.xla_device_metadata()->jit_device_type()); + return Status::OK(); + } + + auto platform = + se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()); + if (!platform.ok()) { + return platform.status(); + } + + xla::StatusOr compiler_for_platform = + xla::Compiler::GetForPlatform(platform.ValueOrDie()); + if (!compiler_for_platform.ok()) { + // In some rare cases (usually in unit tests with very small clusters) we + // may end up transforming an XLA cluster with at least one GPU operation + // (which would normally force the cluster to be compiled using XLA:GPU) + // into an XLA cluster with no GPU operations (i.e. containing only CPU + // operations). Such a cluster can fail compilation (in way that + // MarkForCompilation could not have detected) if the CPU JIT is not linked + // in. + // + // So bail out of _XlaCompile in this case, and let the executor handle the + // situation for us. 
+ const Status& status = compiler_for_platform.status(); + if (status.code() == error::NOT_FOUND) { + return errors::Unimplemented("Could not find compiler for platform ", + platform.ValueOrDie()->Name(), ": ", + status.ToString()); + } + } + + xla::LocalClientOptions client_options; + client_options.set_platform(platform.ValueOrDie()); + client_options.set_intra_op_parallelism_threads( + device->tensorflow_cpu_worker_threads()->num_threads); + auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); + if (!client.ok()) { + return client.status(); + } + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(), + ®istration)) { + return errors::InvalidArgument("No JIT device registered for ", + platform_info.device_type().type()); + } + *cache = new XlaCompilationCache( + client.ValueOrDie(), DeviceType(registration->compilation_device_name)); + return Status::OK(); +} + +XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) { + auto device = static_cast(device_base); + se::Platform::Id platform_id = nullptr; + const XlaDevice::Metadata* xla_device_metadata = nullptr; + se::DeviceMemoryAllocator* custom_allocator = nullptr; + + if (device->device_type() == DEVICE_CPU) { + platform_id = se::host::kHostPlatformId; + } else if (device->device_type() == DEVICE_GPU) { + platform_id = device->tensorflow_gpu_device_info() + ->stream->parent() + ->platform() + ->id(); + } else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata) + .ok()) { + // If we are on an XlaDevice, use the underlying XLA platform's allocator + // directly. We could use the StreamExecutor's allocator which may + // theoretically be more correct, but XLA returns a nice OOM message in a + // Status and StreamExecutor does not. + // + // Importantly we can't use ctx->device()->GetAllocator() as the allocator + // (which xla_allocator above uses) as on an XlaDevice, this is a dummy + // allocator that returns XlaTensor objects. The XlaCompiler needs a real + // allocator to allocate real buffers. + platform_id = xla_device_metadata->platform()->id(); + custom_allocator = + xla_device_metadata->client()->backend().memory_allocator(); + } + + return XlaPlatformInfo(DeviceType(device->device_type()), platform_id, + xla_device_metadata, custom_allocator); +} + +se::DeviceMemoryAllocator* GetAllocator( + absl::optional* tf_allocator_adapter, + DeviceBase* device, se::Stream* stream, + const XlaPlatformInfo& platform_info) { + if (platform_info.custom_allocator()) { + return platform_info.custom_allocator(); + } + if (!stream) { + // Stream is not set for the host platform. 
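XlaPlatformInfoFromDevice() above is what lets the same kernels run on both regular and XLA devices: for DEVICE_CPU it picks the host platform, for DEVICE_GPU the platform of the device's stream executor, and for XLA devices it reuses the metadata and its allocator. A short sketch of combining it with BuildXlaCompilationCache(); `ctx` is assumed to be the kernel's construction or compute context:

// Platform info is derived once; regular devices then need a cache built
// against a LocalClient for that platform, while XLA devices reuse metadata.
const XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(ctx->device());
if (!platform_info.is_on_xla_device()) {
  XlaCompilationCache* cache = nullptr;
  TF_RETURN_IF_ERROR(
      BuildXlaCompilationCache(ctx->device(), platform_info, &cache));
  core::ScopedUnref cache_ref(cache);  // normally kept alive via the ResourceMgr
}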
+ se::Platform* platform = + se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()) + .ValueOrDie(); + tf_allocator_adapter->emplace(device->GetAllocator({}), platform); + return &tf_allocator_adapter->value(); + } + tf_allocator_adapter->emplace(device->GetAllocator({}), stream); + return &tf_allocator_adapter->value(); +} + +XlaCompiler::Options GenerateCompilerOptions( + const XlaCompilationCache& cache, + const FunctionLibraryRuntime& function_library, DeviceBase* device, + se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars, + absl::optional* tf_allocator_adapter) { + XlaCompiler::Options options; + options.client = static_cast(cache.client()); + if (stream != nullptr) { + options.device_ordinal = stream->parent()->device_ordinal(); + } + options.device_type = cache.device_type(); + options.flib_def = function_library.GetFunctionLibraryDefinition(); + options.graph_def_version = function_library.graph_def_version(); + options.allow_cpu_custom_calls = + (platform_info.platform_id() == se::host::kHostPlatformId); + options.device_allocator = + GetAllocator(tf_allocator_adapter, device, stream, platform_info); + if (platform_info.xla_device_metadata()) { + options.shape_representation_fn = + platform_info.xla_device_metadata()->shape_representation_fn(); + } + // If reference variables are not present in the graph, we can safely alias + // passthrough parameters without performing a copy. + options.alias_passthrough_params = + !has_ref_vars && !platform_info.is_on_xla_device(); + return options; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_platform_info.h b/tensorflow/compiler/jit/xla_platform_info.h new file mode 100644 index 00000000000..bfb438cc398 --- /dev/null +++ b/tensorflow/compiler/jit/xla_platform_info.h @@ -0,0 +1,112 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/stream_executor/tf_allocator_adapter.h" + +namespace tensorflow { + +// Holds some information about the platform on which an +// XlaLaunch/_XlaCompile/_XlaRun op must run on. Provides a common layer of +// abstraction for normal and XLA devices. 
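GenerateCompilerOptions() centralizes what the JIT kernels previously assembled by hand: client and device type from the cache, function library and graph_def_version from the runtime, an allocator from GetAllocator() (stream-backed when a stream exists, host-platform otherwise), and shape_representation_fn only when running on an XLA device. A condensed caller-side sketch matching the call in xla_compile_on_demand_op.cc above:

// Assembling compiler options from the cache, runtime and platform info.
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::Stream* stream =
    ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
XlaCompiler::Options options = GenerateCompilerOptions(
    *cache, *ctx->function_library(), ctx->device(), stream, platform_info,
    /*has_ref_vars=*/true, &tf_allocator_adapter);
// options.device_allocator is stream-backed on GPU and host-backed on CPU;
// alias_passthrough_params is only enabled without ref vars and off XLA devices.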
+class XlaPlatformInfo { + public: + XlaPlatformInfo() : device_type_("") {} + XlaPlatformInfo(XlaPlatformInfo&&) = default; + explicit XlaPlatformInfo(const DeviceType device_type, + se::Platform::Id platform_id, + const XlaDevice::Metadata* xla_device_metadata, + se::DeviceMemoryAllocator* device_allocator) + : device_type_(device_type), + platform_id_(platform_id), + xla_device_metadata_(xla_device_metadata), + device_allocator_(device_allocator) {} + + XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default; + + bool UseMultipleStreams() const { + return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams(); + } + + // Non-null only when run on an XLA device. + se::DeviceMemoryAllocator* custom_allocator() const { + return device_allocator_; + } + + DeviceType device_type() const { return device_type_; } + + // This is equal to xla_device_metadata()->platform()->id() if + // xla_device_metadata() is not nullptr. + se::Platform::Id platform_id() const { return platform_id_; } + + // This may be null if the op this XlaPlatformInfo is for was not placed on an + // XLA device. + const XlaDevice::Metadata* xla_device_metadata() const { + return xla_device_metadata_; + } + bool is_on_xla_device() const { return xla_device_metadata() != nullptr; } + + private: + DeviceType device_type_; + se::Platform::Id platform_id_; + + // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the + // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the + // XlaLaunch/_XlaCompile/_XlaRun OpKernel. + const XlaDevice::Metadata* xla_device_metadata_; + + // If the op associated with this XlaPlatformInfo is placed on an XLA device + // then device_allocator_ is the xla::Backend's memory allocator. If the op + // is placed on a regular CPU or GPU device then device_allocator_ is null. + se::DeviceMemoryAllocator* device_allocator_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo); +}; + +// Returns created XLA compilation cache. +Status BuildXlaCompilationCache(DeviceBase* dev, + const XlaPlatformInfo& platform_info, + XlaCompilationCache** cache); + +// Returns information about the platform from kernel context. +XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device); + +// Returns allocator from platform info if non-null, or populate and return a +// pointer to the allocator adapter with allocator from context. +// +// This is necessary because for XLA devices the underlying TF allocator returns +// dummy tensors. +// +// `stream` parameter is nullable when running on host. +se::DeviceMemoryAllocator* GetAllocator( + absl::optional* tf_allocator_adapter, + DeviceBase* device, se::Stream* stream, + const XlaPlatformInfo& platform_info); + +// Returns created options for the XLA compiler, and writes the used allocator +// into `tf_allocator_adapter`. 
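XlaPlatformInfo is move-only and intended to be computed once in an OpKernel constructor and stored as a const member, which is exactly how XlaCompileOnDemandOp now uses it. A minimal sketch of that pattern; "MyJitKernel" is a hypothetical kernel used only for illustration:

// Capturing platform information once, at kernel-construction time.
class MyJitKernel : public OpKernel {
 public:
  explicit MyJitKernel(OpKernelConstruction* ctx)
      : OpKernel(ctx),
        platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}

  void Compute(OpKernelContext* ctx) override {
    // platform_info_ drives allocator and metadata decisions here.
  }

 private:
  const XlaPlatformInfo platform_info_;  // move-only; never copied
};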
+XlaCompiler::Options GenerateCompilerOptions( + const XlaCompilationCache& cache, + const FunctionLibraryRuntime& function_library, DeviceBase* device, + se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars, + absl::optional* tf_allocator_adapter); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_ diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index dc358760534..2da1501819c 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -71,18 +71,6 @@ class XlaTensor { shaped_buffer_ = std::move(shaped_buffer); } - // Some tensors on the device may have known values on the host. We use these - // in on-demand mode to avoid re-copying values from the device if we know the - // host value already. - - // Return true if this XlaTensor contains a host tensor. - bool has_host_tensor() const { return host_tensor_.has_value(); } - // Return the contained host tensor. - // REQUIRES: has_host_tensor() - const Tensor& host_tensor() const { return *host_tensor_; } - // Sets the contained host tensor. - void set_host_tensor(const Tensor& tensor) { host_tensor_.emplace(tensor); } - // Adds synchronization events to 'stream' that wait for this tensor to be // defined on 'stream'. Does nothing if the tensor is already defined on that // stream. diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 01c187790b7..b1870b15595 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -24,11 +24,40 @@ filegroup( srcs = glob(["**/*.td"]), ) +cc_library( + name = "string_container_utils", + hdrs = ["utils/string_container_utils.h"], + deps = [ + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + ], +) + +cc_library( + name = "array_container_utils", + hdrs = ["utils/array_container_utils.h"], + deps = [ + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + ], +) + +cc_library( + name = "name_utils", + srcs = ["utils/name_utils.cc"], + hdrs = ["utils/name_utils.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "op_or_arg_name_mapper", srcs = ["op_or_arg_name_mapper.cc"], hdrs = ["op_or_arg_name_mapper.h"], deps = [ + ":name_utils", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -40,14 +69,14 @@ cc_library( srcs = ["tf_mlir_opt_main.cc"], deps = [ ":init_mlir", + "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "//tensorflow/core:lib", - "//tensorflow/core/platform:logging", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:IR", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:MlirOptLib", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", + "@llvm-project//mlir:Shape", ], ) @@ -64,14 +93,13 @@ cc_library( # xla-legalize-tf-with-tf2xla pass. 
"//tensorflow/compiler/jit", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration", "//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf", "//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize", "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize", "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_pass", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", @@ -127,11 +155,8 @@ tf_cc_binary( deps = [ ":passes", ":tf_mlir_opt_main", - "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", - "//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration", "//tensorflow/compiler/mlir/xla:all_xla_passes_for_testing", ], ) @@ -141,10 +166,9 @@ tf_cc_binary( srcs = ["tf_mlir_translate_main.cc"], deps = [ ":init_mlir", - "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow:tf_xla_mlir_translate", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow:translate_registration", @@ -157,7 +181,7 @@ tf_cc_binary( "//tensorflow/stream_executor/lib", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:Translation", @@ -168,3 +192,5 @@ filegroup( name = "litfiles", srcs = glob(["runlit*py"]), ) + +exports_files(["run_lit.sh"]) diff --git a/tensorflow/compiler/mlir/glob_lit_test.bzl b/tensorflow/compiler/mlir/glob_lit_test.bzl index edbf3663a89..1fa57babdae 100644 --- a/tensorflow/compiler/mlir/glob_lit_test.bzl +++ b/tensorflow/compiler/mlir/glob_lit_test.bzl @@ -43,10 +43,10 @@ def _run_lit_test(name, data, size, tags, driver, features, exec_properties): and specifying a default driver will abort the tests. features: [str], list of extra features to enable. """ - if driver != _default_driver: - fail("There is no present support for custom drivers. Please omit" + - " the driver parameter when running this test. If you require" + - " custom driver support, please file an issue to request it.") + + # Remove the default_driver from the data: it does not exist as a file and is + # just a placeholder from the copybara rewrite. + data = [d for d in data if d != _default_driver] # Disable tests on windows for now, to enable testing rest of all xla and mlir. 
native.py_test( diff --git a/tensorflow/compiler/mlir/hlo/.gitignore b/tensorflow/compiler/mlir/hlo/.gitignore new file mode 100644 index 00000000000..cc1696bf575 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/.gitignore @@ -0,0 +1,4 @@ +build +llvm-project +llvm-build + diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index 9eee39894e4..0e167519263 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -17,6 +17,7 @@ package_group( "//learning/brain/experimental/mlir/...", "//learning/brain/google/xla/kernels/...", "//learning/brain/google/xla/mlir/...", + "//learning/deepmind/partir/...", "//learning/pathways/data_parallel/tf2xla/...", "//platforms/xla/...", "//tensorflow/compiler/mlir/...", @@ -41,6 +42,7 @@ filegroup( "include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td", "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", @@ -298,6 +300,7 @@ cc_library( ":lhlo_ops_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:CopyOpInterface", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SideEffects", @@ -310,17 +313,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "hlo_dialect_force_registration", - srcs = ["lib/Dialect/mhlo/IR/dialect_registration.cc"], - deps = [ - ":hlo", - ":lhlo", - "@llvm-project//mlir:IR", - ], - alwayslink = 1, -) - cc_library( name = "hlo_dialect_registration", srcs = ["lib/Dialect/mhlo/IR/init.cc"], @@ -341,6 +333,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", @@ -348,6 +341,22 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "mhlo_control_flow_to_scf", + srcs = ["lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc"], + hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], + deps = [ + ":hlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + cc_library( name = "map_lmhlo_to_scalar_op", hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"], @@ -404,6 +413,7 @@ cc_library( cc_library( name = "lhlo_legalize_to_llvm", srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc"], + hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], deps = [ ":lhlo", "@llvm-project//mlir:IR", @@ -419,7 +429,10 @@ cc_library( cc_library( name = "legalize_to_linalg", srcs = ["lib/Dialect/mhlo/transforms/legalize_to_linalg.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/transforms/passes.h", + "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", + ], deps = [ ":hlo", ":lhlo", @@ -438,9 +451,13 @@ cc_library( cc_library( name = "transform_unranked_hlo", srcs = ["lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], + hdrs = [ + 
"include/mlir-hlo/Dialect/mhlo/transforms/passes.h", + "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", + ], deps = [ ":hlo", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Shape", @@ -458,6 +475,7 @@ cc_library( ":lhlo", ":map_lmhlo_to_scalar_op", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", @@ -476,9 +494,11 @@ cc_library( deps = [ ":lhlo", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", @@ -486,21 +506,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "lhlo_copy_removal", - srcs = ["lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], - deps = [ - ":lhlo", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - ], - alwayslink = 1, -) - cc_library( name = "hlo_legalize_to_lhlo", srcs = ["lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc"], @@ -681,7 +686,6 @@ cc_library( ], deps = [ ":hlo", - ":hlo_dialect_force_registration", ":lower_complex_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -732,6 +736,7 @@ cc_library( srcs = ["lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"], hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], deps = [ + ":chlo_legalize_to_hlo_inc_gen", ":hlo", "@llvm-project//mlir:IR", "@llvm-project//mlir:SCFDialect", @@ -741,6 +746,25 @@ cc_library( ], ) +gentbl( + name = "chlo_legalize_to_hlo_inc_gen", + strip_include_prefix = "lib/Dialect/mhlo/transforms/", + tbl_outs = [ + ( + "-gen-rewriters", + "lib/Dialect/mhlo/transforms/generated_chlo_legalize_to_hlo.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td", + td_relative_includes = [ + "include", + ], + td_srcs = [ + ":hlo_ops_td_files", + ], +) + cc_library( name = "test_passes", srcs = [ @@ -759,8 +783,6 @@ cc_library( ":lhlo_legalize_to_llvm", # build-cleaner: keep ":materialize_broadcasts", # build-cleaner: keep ":unfuse_batch_norm", # build-cleaner: keep - "@llvm-project//mlir:AffineToStandardTransforms", - "@llvm-project//mlir:CFGTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LLVMDialect", @@ -793,11 +815,11 @@ cc_library( ":legalize_to_linalg", ":legalize_to_standard", ":lhlo", - ":lhlo_copy_removal", ":lhlo_fuse_linalg", ":lhlo_legalize_to_affine", ":lhlo_legalize_to_gpu", ":lhlo_legalize_to_parallel_loops", + ":mhlo_control_flow_to_scf", ":mhlo_fusion", ":mhlo_to_mhlo_lowering_patterns", ":sink_constants_to_control_flow", @@ -807,13 +829,6 @@ cc_library( ], ) -cc_library( - name = "register_all_passes", - srcs = ["lib/Dialect/mhlo/transforms/register_all_passes.cc"], - deps = [":all_passes"], - alwayslink = 1, -) - cc_binary( name = "mlir-hlo-opt", srcs = [ @@ -821,7 +836,8 @@ cc_binary( ], deps = [ ":all_passes", - ":hlo_dialect_registration", + ":hlo", + ":lhlo", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", diff --git 
a/tensorflow/compiler/mlir/hlo/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/CMakeLists.txt new file mode 100644 index 00000000000..c4e2ea123df --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/CMakeLists.txt @@ -0,0 +1,94 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +cmake_minimum_required(VERSION 3.13.4) + +if(POLICY CMP0068) + cmake_policy(SET CMP0068 NEW) + set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON) +endif() + +if(POLICY CMP0075) + cmake_policy(SET CMP0075 NEW) +endif() + +if(POLICY CMP0077) + cmake_policy(SET CMP0077 NEW) +endif() + +#------------------------------------------------------------------------------- +# Project setup and globals +#------------------------------------------------------------------------------- + +project(mlir-hlo LANGUAGES CXX C) +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 14) +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules") + +#------------------------------------------------------------------------------- +# Options and settings +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# MSVC defaults +#------------------------------------------------------------------------------- + +if(MSVC) + add_compile_options( + $<$:/MD> + $<$:/MD> + $<$:/MD> + ) +endif() + +#------------------------------------------------------------------------------- +# MLIR/LLVM Configuration +#------------------------------------------------------------------------------- + +find_package(MLIR REQUIRED CONFIG) +message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") +message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") + +if(LLVM_ENABLE_ZLIB) + find_package(ZLIB) +endif() + +include(TableGen) +include(AddLLVM) +include(AddMLIR) +include(HandleLLVMOptions) +include_directories(${LLVM_INCLUDE_DIRS}) +include_directories(${MLIR_INCLUDE_DIRS}) +include_directories(${PROJECT_SOURCE_DIR}/include) +include_directories(${PROJECT_BINARY_DIR}/include) +include_directories(${PROJECT_BINARY_DIR}/) +link_directories(${LLVM_BUILD_LIBRARY_DIR}) +add_definitions(${LLVM_DEFINITIONS}) + +#------------------------------------------------------------------------------- +# Directory setup +#------------------------------------------------------------------------------- + +set(MLIR_HLO_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +set(MLIR_HLO_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) + +add_custom_target(check-mlir-hlo) + +add_subdirectory(include/mlir-hlo) +add_subdirectory(lib) +add_subdirectory(tools) +add_subdirectory(tests) diff --git a/tensorflow/compiler/mlir/hlo/README.md b/tensorflow/compiler/mlir/hlo/README.md index 1be6fb29d13..9eaa14031fd 100644 --- a/tensorflow/compiler/mlir/hlo/README.md +++ b/tensorflow/compiler/mlir/hlo/README.md @@ -1,4 +1,4 @@ -# MLIR-HLO +# MLIR-HLO: A 
Standalone "HLO" MLIR-based Compiler The code here exists in two places: @@ -22,10 +22,43 @@ upstream. ## QuickStart: building and testing -TODO +These instructions work on Linux, you may have to adjust for your plaform. + +To build the code in this repository, you need a clone of the LLVM/MLIR git +repository: + + $ git clone https://github.com/llvm/llvm-project.git + + +You need to make sure you have the right commit checked out in the LLVM +repository (you need to do this every time you pull from this repo): + + $ (cd llvm-project && git checkout $(cat build_tools/llvm_version.txt)) + +We provide a script to configure and build LLVM/MLIR: + + $ build_tools/build_mlir.sh ${PWD}/llvm-project/ ${PWD}/llvm-build + +Again this is something to do every time you pull from this repository and the +LLVM revision changes. + +Finally you can build and test this repository: + + $ mkdir build && cd build + $ cmake .. -GNinja \ + -DLLVM_ENABLE_LLD=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=On \ + -DMLIR_DIR=${PWD}/../llvm-build/lib/cmake/mlir + $ ninja check-mlir-hlo + ## Overview +MLIR-HLO aims to provide an end-to-end compiler for CPU and GPU, as well as +building reusable blocks for other accelerators. This is heavily inspired by the +success of XLA. + [XLA](https://www.tensorflow.org/xla/) (Accelerated Linear Algebra) is a domain-specific compiler framework and execution environment for linear algebra, which powers code-generation for ML frameworks like TensorFlow, JAX, and others. diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/CMakeLists.txt new file mode 100644 index 00000000000..92759d76383 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/CMakeLists.txt @@ -0,0 +1,16 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +add_subdirectory(Dialect) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/CMakeLists.txt new file mode 100644 index 00000000000..5ee1a1924ec --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/CMakeLists.txt @@ -0,0 +1,16 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +add_subdirectory(mhlo) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/CMakeLists.txt new file mode 100644 index 00000000000..e138afa587f --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/CMakeLists.txt @@ -0,0 +1,17 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +add_subdirectory(IR) +add_subdirectory(transforms) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt new file mode 100644 index 00000000000..09bdca84cd3 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt @@ -0,0 +1,31 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Need a separate function because of the .cc vs .cpp used in the one provided by MLIR +function(add_mlir_hlo_dialect dialect dialect_namespace) + set(LLVM_TARGET_DEFINITIONS ${dialect}.td) + mlir_tablegen(${dialect}.h.inc -gen-op-decls) + mlir_tablegen(${dialect}.cc.inc -gen-op-defs) + mlir_tablegen(${dialect}_structs.h.inc -gen-struct-attr-decls) + mlir_tablegen(${dialect}_structs.cc.inc -gen-struct-attr-defs) + add_public_tablegen_target(MLIR${dialect}IncGen) + add_dependencies(mlir-headers MLIR${dialect}IncGen) +endfunction() + +add_mlir_hlo_dialect(chlo_ops chlo) +add_mlir_hlo_dialect(hlo_ops mhlo) +add_mlir_hlo_dialect(lhlo_ops lmhlo) + +add_mlir_interface(infer_fusibility_op_interface) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index 14a22e92a74..05b22770401 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -24,6 +24,7 @@ limitations under the License. 
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -32,14 +33,39 @@ namespace mlir { namespace chlo { class HloClientDialect : public Dialect { + void initialize(); + public: - explicit HloClientDialect(MLIRContext *context); + explicit HloClientDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, + TypeID::get()) { + initialize(); + } static StringRef getDialectNamespace() { return "chlo"; } }; +} // namespace chlo +} // namespace mlir + #define GET_OP_CLASSES #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc" +namespace mlir { +namespace chlo { + +template +static Value getConstantLike(OpBuilder& b, Location loc, T constant, + Value val) { + Type ty = getElementTypeOrSelf(val.getType()); + + auto getAttr = [&]() -> Attribute { + if (ty.isa()) return b.getIntegerAttr(ty, constant); + if (ty.isa()) return b.getFloatAttr(ty, constant); + llvm_unreachable("unhandled element type"); + }; + return b.create(loc, getAttr(), val); +} + } // namespace chlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index d7cdd12d351..54b40fe0c94 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -37,7 +37,7 @@ include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" def HLOClient_Dialect : Dialect { let name = "chlo"; - let cppNamespace = "chlo"; + let cppNamespace = "::mlir::chlo"; let summary = [{ Client HLO Ops }]; @@ -344,14 +344,16 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp< //===----------------------------------------------------------------------===// class HLOClient_UnaryElementwiseOp traits, - Type TensorType>: HLOClient_Op { + Type TensorType> : HLOClient_Op { let arguments = (ins TensorType:$operand); - let results = (outs TensorType); + let results = (outs TensorType:$result); + + let assemblyFormat = "$operand attr-dict `:` type($operand)"; } -def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos", - [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { +def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [], + HLO_FpOrComplexTensor> { let summary = "Acos operator"; let description = [{ @@ -364,6 +366,37 @@ def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos", }]; } +def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", [], + HLO_FpOrComplexTensor> { + let summary = "Tan operation"; + + let description = [{ + Returns `Tan(operand)` element-wise. + + $$ + \tan(x) = \sin(x) / \cos(x) + $$ + }]; +} + +def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like", + [NoSideEffect, SameOperandsAndResultShape, + InferTypeOpInterface, + DeclareOpInterfaceMethods, + NativeOpTrait<"InferTensorType">]> { + let summary = "Constant like operator"; + + let description = [{ + Returns a splat constant of the same shape as the operand. + }]; + + // TODO(jpienaar): value's type could be tightened. 
+ let arguments = (ins AnyAttr:$value, HLO_Tensor:$operand); + let results = (outs HLO_Tensor); + + let hasCanonicalizer = 1; +} + //===----------------------------------------------------------------------===// // Broadcasting compare op //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index 0036cc0dc19..60ee4e613eb 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -19,7 +19,6 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #include "llvm/ADT/StringRef.h" -#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" @@ -32,11 +31,14 @@ limitations under the License. #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +// clang-format off +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +// clang-format on + namespace mlir { class OpBuilder; -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc" - namespace mhlo { class MhloDialect : public Dialect { @@ -56,22 +58,9 @@ class MhloDialect : public Dialect { void printType(Type type, DialectAsmPrinter &os) const override; }; -namespace HLOTypes { -enum Kind { - Token = Type::FIRST_XLA_HLO_TYPE, -}; -} // namespace HLOTypes - class TokenType : public Type::TypeBase { public: using Base::Base; - - static TokenType get(MLIRContext *context) { - return Base::get(context, HLOTypes::Token); - } - - // Support method to enable LLVM-style type casting. 
- static bool kindof(unsigned kind) { return kind == HLOTypes::Token; } }; // Shape derivation function that computes the shape of the result based on @@ -90,10 +79,10 @@ LogicalResult deriveShapeFromFirstOperand( OpBuilder *builder, Operation *op, SmallVectorImpl *reifiedReturnShapes); -#define GET_OP_CLASSES -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" - } // end namespace mhlo } // end namespace mlir +#define GET_OP_CLASSES +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" + #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index e83bf874c62..351e8bdae0e 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -27,7 +27,7 @@ include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" def HLO_Dialect : Dialect { let name = "mhlo"; - let cppNamespace = "mhlo"; + let cppNamespace = "::mlir::mhlo"; } class HLO_Op traits> : @@ -67,8 +67,7 @@ def HLO_ConstOp : HLO_Op<"constant", "OpBuilder &builder, OperationState &result, Attribute value" >]; - let printer = [{ return Print(*this, &p); }]; - let parser = [{ return ParseConstOp(&parser, &result); }]; + let assemblyFormat = "attr-dict $value"; let hasFolder = 1; @@ -225,11 +224,14 @@ def HLO_LogisticOp: HLO_UnaryElementwiseOp<"logistic", def HLO_NotOp: HLO_UnaryElementwiseOp<"not", [NoSideEffect, SameOperandsAndResultType], HLO_PredOrIntTensor>, - BASE_HLO_NotOp; + BASE_HLO_NotOp { +} def HLO_NegOp: HLO_UnaryElementwiseOp<"negate", [NoSideEffect, SameOperandsAndResultType], HLO_IntFpOrComplexTensor>, - BASE_HLO_NegOp; + BASE_HLO_NegOp { + let hasFolder = 1; +} def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, @@ -263,7 +265,9 @@ def HLO_SinOp: HLO_UnaryElementwiseOp<"sine", def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt", [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, - BASE_HLO_SqrtOp; + BASE_HLO_SqrtOp { + let hasFolder = 1; +} def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh", [NoSideEffect, SameOperandsAndResultType], @@ -380,6 +384,8 @@ class HLO_BinaryLogicalElementwiseOp : HLO_PredOrIntTensor:$lhs, HLO_PredOrIntTensor:$rhs ); + + let hasFolder = 1; } def HLO_AndOp: HLO_BinaryLogicalElementwiseOp<"and">, BASE_HLO_AndOp; @@ -492,9 +498,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> { def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>, BASE_HLO_ReplicaIdOp { - // TODO(prakalps): The output should unsigned 32-bit integer but mlir does - // not differentiate between signed and unsigned int. 
- let results = (outs I32Tensor); + let results = (outs TensorOf<[UI32]>); } //===----------------------------------------------------------------------===// @@ -671,11 +675,13 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { "OpBuilder &builder, OperationState &results, " "ValueRange values">]; + let hasCanonicalizer = 1; } -def HLO_CompareOp: HLO_Op<"compare", - [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]>, - BASE_HLO_CompareOp { +def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands, + SameOperandsAndResultShape, + DeclareOpInterfaceMethods]>, BASE_HLO_CompareOp { let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, @@ -1067,6 +1073,8 @@ def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp { ); let results = (outs HLO_Tensor); + + let hasCanonicalizer = 1; } def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>, @@ -1079,6 +1087,8 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>, // XLA semantics is available. This limitation is because of the current XLA // implementation. let results = (outs I32Tensor); + + let hasFolder = 1; } def HLO_MapOp: HLO_Op<"map", @@ -1143,7 +1153,10 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>, } // TODO(jpienaar): Add broadcastable trait. -def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods]>, BASE_HLO_SelectOp { +def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + ]>, BASE_HLO_SelectOp { let arguments = (ins HLO_PredTensor:$pred, HLO_Tensor:$on_true, @@ -1151,6 +1164,8 @@ def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods { } //===----------------------------------------------------------------------===// -// MHLO RngUniform Operator. +// MHLO RNG Operators. //===----------------------------------------------------------------------===// + def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp { let arguments = (ins HLO_PredIntOrFpTensor:$a, @@ -1355,6 +1371,19 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp { let hasCustomHLOConverter = 1; } +def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>, BASE_HLO_RngBitGeneratorOp { + let arguments = (ins + // TODO(jpienaar): This could be an enum instead. + I32Attr:$rng_algorithm, + HLO_IntOrFpTensor:$initial_state + ); + + let results = (outs HLO_TensorOrTuple:$result); + + // TODO(jpienaar): This should not be needed. + let hasCustomHLOConverter = 1; +} + //===----------------------------------------------------------------------===// // MHLO Quantize Operator. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index 7f9784d7f11..2f80545ad19 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -316,6 +316,19 @@ class BASE_HLO_RealOp { }]; } +class BASE_HLO_RngBitGeneratorOp { + string summary = "Uniform random number generator operator"; + + string description = [{ + Returns an output with a given shape filled with uniform random bits using + the specified algorithm (or backend default) and returns an updated state + (with the same shape as initial state) and the generated random data. 
+ + See + https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator. + }]; +} + class BASE_HLO_RoundOp { string summary = "Round operator"; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td index e1ae9e1fb89..32940cbc623 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td @@ -27,6 +27,9 @@ def CastIntElementsAttr : NativeCodeCall<"$0.cast()">; class ConstantSplat : NativeCodeCall< "hlo::getSplat(&$_builder, $0, " # value # ")">; +class HLO_ConstantLike : NativeCodeCall< + "chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; + def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; def BinBroadcastDimensions : NativeCodeCall< diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index bb9b29096f3..cc24e17c001 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -27,14 +27,17 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" namespace mlir { class OpBuilder; +} // namespace mlir #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc" +namespace mlir { namespace lmhlo { class LmhloDialect : public Dialect { @@ -43,10 +46,10 @@ class LmhloDialect : public Dialect { static StringRef getDialectNamespace() { return "lmhlo"; } }; -#define GET_OP_CLASSES -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" - } // namespace lmhlo } // end namespace mlir +#define GET_OP_CLASSES +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" + #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 3fa46584ca2..9225d0289dd 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -34,13 +34,14 @@ limitations under the License. 
#define LHLO_OPS include "mlir/IR/OpBase.td" +include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" def LHLO_Dialect : Dialect { let name = "lmhlo"; - let cppNamespace = "lmhlo"; + let cppNamespace = "::mlir::lmhlo"; } //===----------------------------------------------------------------------===// @@ -81,6 +82,8 @@ def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp { ElementsAttr:$value, Arg:$output ); + + let hasCanonicalizer = 1; } def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp { @@ -614,11 +617,16 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp { ); } -def LHLO_CopyOp: LHLO_Op<"copy", []>, BASE_HLO_CopyOp { +def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp { let arguments = (ins Arg:$operand, Arg:$output ); + + let extraClassDeclaration = [{ + Value getSource() { return operand();} + Value getTarget() { return output(); } + }]; } def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp { diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h index 5773901ad78..cb0af3a159d 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h @@ -17,10 +17,11 @@ limitations under the License. #define MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_ namespace mlir { +class DialectRegistry; namespace mhlo { -void registerAllDialects(); - +// Add chlo, mhlo, lmhlo dialects to the provided registry. +void registerAllMhloDialects(DialectRegistry ®istry); } } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/CMakeLists.txt new file mode 100644 index 00000000000..6de6851b8d7 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/CMakeLists.txt @@ -0,0 +1,23 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set(LLVM_TARGET_DEFINITIONS mhlo_passes.td) +mlir_tablegen(mhlo_passes.h.inc -gen-pass-decls -name MHLO) +add_public_tablegen_target(MLIRMhloPassIncGen) + +set(LLVM_TARGET_DEFINITIONS lmhlo_passes.td) +mlir_tablegen(lmhlo_passes.h.inc -gen-pass-decls -name LMHLO) +add_public_tablegen_target(MLIRLmhloPassIncGen) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td index 963ff5dbacf..39b4ca65043 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td @@ -15,12 +15,6 @@ limitations under the License. 
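Editor's aside, not part of the patch: with `CopyOpInterface` attached to `lmhlo.copy` above (through `getSource`/`getTarget`), generic analyses can treat it as an opaque copy. A small illustrative sketch:

    #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
    #include "mlir/Interfaces/CopyOpInterface.h"

    // Works for any op implementing CopyOpInterface, including lmhlo.copy
    // after this change.
    bool isSelfCopy(mlir::Operation *op) {
      if (auto copy = mlir::dyn_cast<mlir::CopyOpInterface>(op))
        return copy.getSource() == copy.getTarget();
      return false;
    }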
include "mlir/Pass/PassBase.td" -def LhloCopyRemovalPass : Pass<"lhlo-copy-removal", "FuncOp"> { - let summary = "Removes redundant LHLO copy operations."; - let constructor = "createLhloCopyRemovalPass()"; -} - - def LhloLegalizeToLinalgPass : Pass<"lhlo-legalize-to-linalg", "FuncOp"> { let summary = "Legalize from LHLO dialect to Linalg dialect."; let constructor = "createLegalizeLhloToLinalgPass()"; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index c51bcfcfe89..d2621759213 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -40,6 +40,7 @@ using HloToLhloOp = typename HloToLhloOpImpl::Type; MAP_HLO_TO_LHLO(AbsOp); MAP_HLO_TO_LHLO(AddOp); MAP_HLO_TO_LHLO(AndOp); +MAP_HLO_TO_LHLO(Atan2Op); MAP_HLO_TO_LHLO(BroadcastInDimOp); MAP_HLO_TO_LHLO(CeilOp); MAP_HLO_TO_LHLO(ConstOp); @@ -52,6 +53,7 @@ MAP_HLO_TO_LHLO(CosOp); MAP_HLO_TO_LHLO(DivOp); MAP_HLO_TO_LHLO(DotOp); MAP_HLO_TO_LHLO(ExpOp); +MAP_HLO_TO_LHLO(FloorOp); MAP_HLO_TO_LHLO(GatherOp); MAP_HLO_TO_LHLO(ImagOp); MAP_HLO_TO_LHLO(IotaOp); @@ -68,9 +70,11 @@ MAP_HLO_TO_LHLO(RsqrtOp); MAP_HLO_TO_LHLO(SelectOp); MAP_HLO_TO_LHLO(SignOp); MAP_HLO_TO_LHLO(SinOp); +MAP_HLO_TO_LHLO(SliceOp); MAP_HLO_TO_LHLO(SqrtOp); MAP_HLO_TO_LHLO(SubOp); MAP_HLO_TO_LHLO(TanhOp); +MAP_HLO_TO_LHLO(TransposeOp); #undef MAP_HLO_TO_LHLO diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 2bb5ab2888d..1199dae1ab2 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -336,6 +336,15 @@ inline Value MapLhloOpToStdScalarOp(Location loc, loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + /// Implements the conversion of HLO op to scalar op (to use within region of a /// linalg.generic op) for compare-select style operations like min/max. 
template diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td index fa3bde24df1..aa0f4c317d4 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td @@ -30,6 +30,11 @@ def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> { let constructor = "createLegalizeControlFlowPass()"; } +def LegalizeControlFlowToScfPass : Pass<"mhlo-control-flow-to-scf", "FuncOp"> { + let summary = "Legalize from MHLO control flow to SCF control flow."; + let constructor = "createControlFlowToScfPass()"; +} + def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> { let summary = "Legalizes gathers to a torch index select."; let constructor = "createLegalizeGatherToTorchIndexSelectPass()"; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index efa116f3f0d..fae79d91b1b 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -30,11 +30,17 @@ template class OperationPass; class Pass; +// Transforms unranked HLO operations to ranked ones where possible. +std::unique_ptr createTransformUnrankedHloPass(); + namespace mhlo { /// Lowers HLO control flow ops to the Standard dialect. std::unique_ptr> createLegalizeControlFlowPass(); +/// Lowers MHLO control flow ops to the SCF dialect. +std::unique_ptr> createControlFlowToScfPass(); + /// Lowers from HLO dialect to Standard dialect. std::unique_ptr> createLegalizeToStdPass(); @@ -49,9 +55,6 @@ std::unique_ptr> createLegalizeToLhloPass( // Lowers from HLO dialect to Linalg dialect. std::unique_ptr> createLegalizeHloToLinalgPass(); -// Transforms unranked HLO operations to ranked ones where possible. -std::unique_ptr> createTransformUnrankedHloPass(); - // Sinks constants implicitly captured in control flow regions. This is // necessary to export to XLA. std::unique_ptr> createSinkConstantsToControlFlowPass(); @@ -92,12 +95,6 @@ std::unique_ptr createLegalizeToGpuPass(); std::unique_ptr createLhloFuseLinalgPass( bool use_parallel_loops = false, llvm::ArrayRef tile_sizes = {}); -// Removes unnecessary LHLO copies which copy from the allocated buffers to the -// block arguments. The block arguments are used instead of all uses of these -// buffers. The buffers are freed. This pass only works in regions that contain -// a single block. -std::unique_ptr createLhloCopyRemovalPass(); - // Lowers from LHLO dialect to parallel loops. std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index 725155e9403..cf21a95db6f 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -20,6 +20,7 @@ limitations under the License. 
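Editor's aside, not part of the patch: a sketch of scheduling the new MHLO-to-SCF lowering via the `createControlFlowToScfPass` factory declared in passes.h above; the header paths follow this snapshot of MLIR and the wrapper name is illustrative.

    #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
    #include "mlir/IR/Function.h"
    #include "mlir/Pass/PassManager.h"

    // The pass is registered on FuncOp (see mhlo_passes.td), so it is nested
    // under the module-level pass manager.
    void addMhloControlFlowToScf(mlir::PassManager &pm) {
      pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createControlFlowToScfPass());
    }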
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/BufferPlacement.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -27,6 +28,12 @@ class LLVMTypeConverter; class LowerToLLVMOptions; class OwningRewritePatternList; class BufferAssignmentPlacer; + +// Populates a collection of rewrite patterns to realize element-wise operations +// on ranked tensors where possible. +void PopulateTransformUnrankedHloPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + namespace mhlo { // Collection of rewrite patterns for lowering a general dot product. @@ -50,8 +57,9 @@ void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, // Collection of rewrite patterns for lowering of HLO to LHLO dialect. void populateHLOToLHLOConversionPattern( - MLIRContext *context, BufferAssignmentPlacer *bufferAssignment, - TypeConverter *converter, OwningRewritePatternList *patterns); + MLIRContext *context, BufferAssignmentTypeConverter *converter, + OwningRewritePatternList *patterns); + // Collection of rewrite patterns for lowering of HLO to Linalg dialect. void populateHLOToLinalgConversionPattern(MLIRContext *context, OwningRewritePatternList *patterns); diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h index 1e2404299b2..1c57073f4ab 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h @@ -38,10 +38,12 @@ bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs, // Emits shape dialect ops to compute the result shape for a broadcasting // binary elementwise op which broadcasts according to "numpy" semantics -// (see above), returning an extents tensor of the resulting shape. -Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, - Value rhs, - OpBuilder& builder); +// (see above), returning a `shape.shape` or an extent tensor of the resulting +// shape. The result should only be an extent tensor in contexts that ensure +// both operands to be broadcastable. +Value ComputeBinaryElementwiseBroadcastingResultExtents( + Location loc, Value lhs, Value rhs, OpBuilder& builder, + bool unsafe_as_extent_tensor); } // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h index 1e335ae6b82..74ea9c9b1a7 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h @@ -65,9 +65,24 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) { // Returns DenseElementsAttr of rank zero with the given element type and the // value. -// Requires `ty` to be either FloatType of IntegerType. +// Requires `ty` to be either FloatType, IntegerType, or ComplexType. DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value); +// Enum type used to specify scalar argument to GetScalarLimitOfType. +enum ScalarLimit { + kLowest, // The scalar corresponding to numeric_limits::lowest. + kInfinityLowest, // Like kMax, but returns -infinity where available. + kMax, // The scalar corresponding to numeric_limits::max. + kInfinityMax, // Like kMax, but returns infinity where available. +}; + +// Returns a scalar limit value for the given type. +// +// The argument 'limit' describes which scalar value to return. 
+// +// Requires `ty` to be either FloatType or IntegerType. +DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit); + } // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/CMakeLists.txt new file mode 100644 index 00000000000..ec65a5ee882 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/CMakeLists.txt @@ -0,0 +1,17 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +add_subdirectory(Dialect) +add_subdirectory(utils) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/CMakeLists.txt new file mode 100644 index 00000000000..5ee1a1924ec --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/CMakeLists.txt @@ -0,0 +1,16 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +add_subdirectory(mhlo) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/CMakeLists.txt new file mode 100644 index 00000000000..e138afa587f --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/CMakeLists.txt @@ -0,0 +1,17 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +add_subdirectory(IR) +add_subdirectory(transforms) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/CMakeLists.txt new file mode 100644 index 00000000000..d7bb5057b00 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/CMakeLists.txt @@ -0,0 +1,82 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
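Editor's aside, not part of the patch: a usage sketch for `GetScalarLimitOfType` declared in hlo_utils.h above, for example to build the identity value of a max-reduction; the wrapper name below is illustrative.

    #include "mlir-hlo/utils/hlo_utils.h"

    // Rank-zero splat holding numeric_limits<T>::lowest() for the given
    // float or integer element type.
    mlir::DenseElementsAttr maxReduceInit(mlir::Type element_type) {
      return mlir::hlo::GetScalarLimitOfType(element_type, mlir::hlo::kLowest);
    }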
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +include_directories(BEFORE + ${CMAKE_CURRENT_BINARY_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}) + +set(LLVM_TARGET_DEFINITIONS hlo_patterns.td) +mlir_tablegen(hlo_patterns.cc.inc -gen-rewriters) +add_public_tablegen_target(MLIRMhloRewriterIncGen) + +set(LLVM_TARGET_DEFINITIONS mhlo_canonicalize.td) +mlir_tablegen(mhlo_canonicalize.inc -gen-rewriters) +add_public_tablegen_target(MLIRMhloCanonicalizeIncGen) + +add_mlir_dialect_library(ChloDialect + chlo_ops.cc + + DEPENDS + MLIRchlo_opsIncGen +) +target_link_libraries(ChloDialect PUBLIC MLIRIR) + +add_mlir_library(MhloInferFusibilityOpInterface + infer_fusibility_op_interface.cc + + DEPENDS + MLIRinfer_fusibility_op_interfaceIncGen +) + + +add_mlir_dialect_library(MhloDialect + hlo_ops.cc + + DEPENDS + MLIRhlo_opsIncGen + MLIRMhloCanonicalizeIncGen + MLIRMhloRewriterIncGen + MLIRinfer_fusibility_op_interfaceIncGen +) +target_link_libraries(MhloDialect + PUBLIC + MLIRIR + MhloInferFusibilityOpInterface + MLIRMhloUtils +) + + +add_mlir_dialect_library(LmhloDialect + lhlo_ops.cc + + DEPENDS + MLIRlhlo_opsIncGen +) +target_link_libraries(LmhloDialect PUBLIC MLIRIR) + + +add_mlir_dialect_library(MhloRegisterDialects + init.cc +DEPENDS + MLIRchlo_opsIncGen + MLIRhlo_opsIncGen + MLIRlhlo_opsIncGen +) +target_link_libraries(MhloRegisterDialects + PUBLIC + ChloDialect + MhloDialect + LmhloDialect +) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_canonicalize.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_canonicalize.td new file mode 100644 index 00000000000..eb92d9e0e46 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_canonicalize.td @@ -0,0 +1,30 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the canonicalize pattern definition file. + +include "mlir/IR/OpBase.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" + +def UnaryToBinaryEinsumEq : NativeCodeCall< + "$_builder.getStringAttr(\",\" + $0.getValue().str())">; + +// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first +// operand. 
+def UnaryEinsumToEinsum : Pat< + (HLO_UnaryEinsumOp $operand, $equation), + (HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)), + $operand, (UnaryToBinaryEinsumEq $equation))>; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc index 99ed8bcb849..99b22a75a14 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -15,10 +15,12 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/utils/broadcast_utils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/TypeUtilities.h" @@ -151,7 +153,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes( } Value computed_shape = hlo::ComputeBinaryElementwiseBroadcastingResultExtents( - loc, lhs, rhs, builder); + loc, lhs, rhs, builder, /*unsafe_as_extent_tensor=*/false); if (!computed_shape) return failure(); reifiedReturnShapes.push_back(computed_shape); return success(); @@ -259,15 +261,62 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp); #undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS #undef BROADCAST_BINARY_OP_DEFS +static LogicalResult Verify(ConstantLikeOp op) { + if (op.value().getType() != op.getType().cast().getElementType()) + return op.emitOpError() << "value's type doesn't match element return type"; + return success(); +} + +LogicalResult ConstantLikeOp::inferReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferedReturnShapes) { + ConstantLikeOp::Adaptor op(operands, attributes); + if (failed(op.verify(location.getValue()))) return failure(); + Type element_type = op.value().getType(); + Type operand_type = op.operand().getType(); + if (operand_type.isa()) { + inferedReturnShapes.emplace_back(element_type); + } else { + const auto& shape = operand_type.cast().getShape(); + inferedReturnShapes.emplace_back(shape, element_type); + } + return success(); +} + +struct ConstantLikeToConstant : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConstantLikeOp op, + PatternRewriter& rewriter) const override { + auto op_type = op.operand().getType().cast(); + if (!op_type.hasStaticShape()) return failure(); + auto type = RankedTensorType::get(op_type.getShape(), op.value().getType()); + ElementsAttr attr = DenseElementsAttr::get(type, op.value()); + rewriter.replaceOpWithNewOp(op.getOperation(), attr); + return success(); + } +}; + +void ConstantLikeOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + +} // namespace chlo +} // namespace mlir + #define GET_OP_CLASSES #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" +namespace mlir { +namespace chlo { + //===----------------------------------------------------------------------===// // chlo Dialect Constructor //===----------------------------------------------------------------------===// -HloClientDialect::HloClientDialect(MLIRContext* context) - : Dialect(getDialectNamespace(), context) { +void HloClientDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" diff --git 
a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc index 69b01009a0d..6711a916896 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -60,7 +60,11 @@ limitations under the License. namespace mlir { #include "hlo_patterns.cc.inc" +} // namespace mlir + #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc" + +namespace mlir { namespace mhlo { Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value, @@ -112,37 +116,6 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices, // ConstOp //===----------------------------------------------------------------------===// -static void Print(ConstOp op, OpAsmPrinter* printer) { - // Print op name. - *printer << op.getOperationName(); - - // Elide attribute value while printing the attribute dictionary. - SmallVector elided_attrs; - elided_attrs.push_back("value"); - printer->printOptionalAttrDict(op.getAttrs(), elided_attrs); - - *printer << ' ' << op.value(); -} - -static ParseResult ParseConstOp(OpAsmParser* parser, OperationState* result) { - if (parser->parseOptionalAttrDict(result->attributes)) return failure(); - - // If colon is not present after attribute dictionary, it should be short form - // and attribute 'value' is outside the dictionary. - if (failed(parser->parseOptionalColon())) { - Attribute value; - if (parser->parseAttribute(value, "value", result->attributes)) - return failure(); - return parser->addTypeToList(value.getType(), result->types); - } - - // Long form should have type of the result after colon. - Type ty; - if (parser->parseType(ty)) return failure(); - result->types.push_back(ty); - return success(); -} - OpFoldResult ConstOp::fold(ArrayRef operands) { assert(operands.empty() && "constant has no operands"); @@ -196,6 +169,71 @@ static LogicalResult Verify(DotGeneralOp op) { return success(); } +//===----------------------------------------------------------------------===// +// GatherOp +//===----------------------------------------------------------------------===// + +// Converts gather ops to slice ops in case we have a single set of constant +// indices. +struct GatherSlice : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GatherOp gather, + PatternRewriter& rewriter) const override { + DenseIntElementsAttr index; + if (!matchPattern(gather.start_indices(), m_Constant(&index))) + return failure(); + + const auto& dnums = gather.dimension_numbers(); + if (dnums.collapsed_slice_dims().getNumElements() != 0 || + dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1) + return failure(); + + // TODO(tberghammer): Remove when the verifier catches this case what is + // invalid if all previous condition holds. 
+ if (index.getNumElements() != dnums.start_index_map().getNumElements()) + return failure(); + + auto slice_end = + llvm::to_vector<8>(gather.slice_sizes().getValues()); + llvm::SmallVector slice_start(slice_end.size(), 0); + for (auto it : llvm::zip(dnums.start_index_map().getIntValues(), + index.getIntValues())) { + int64_t map_index = std::get<0>(it).getSExtValue(); + int64_t offset = std::get<1>(it).getSExtValue(); + slice_start[map_index] += offset; + slice_end[map_index] += offset; + } + + llvm::SmallVector slice_stride(slice_end.size(), 1); + rewriter.replaceOpWithNewOp( + gather, gather.getType(), gather.getOperand(0), + GetI64ElementsAttr(slice_start, &rewriter), + GetI64ElementsAttr(slice_end, &rewriter), + GetI64ElementsAttr(slice_stride, &rewriter)); + return success(); + } +}; + +void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// GetDimensionSizeOp +//===----------------------------------------------------------------------===// + +/// Fold get_dimension_size when the said shape dimension is a constant. +OpFoldResult GetDimensionSizeOp::fold(ArrayRef attrs) { + RankedTensorType type = operand().getType().cast(); + int32_t dim = dimension(); + if (type.isDynamic(dim)) return {}; + // The result type is always is a 0-d i32 tensor. + return DenseIntElementsAttr::get( + getResult().getType().cast(), type.getDimSize(dim)); +} + //===----------------------------------------------------------------------===// // IotaOp //===----------------------------------------------------------------------===// @@ -207,7 +245,7 @@ static LogicalResult Verify(IotaOp op) { if (shape.getRank() == 0) return op.emitOpError() << "does not support scalars."; - auto iota_dimension = op.iota_dimension().getSExtValue(); + auto iota_dimension = op.iota_dimension(); if (iota_dimension >= shape.getRank() || iota_dimension < 0) return op.emitOpError() << "iota dimension cannot go beyond the output " "rank or be negative."; @@ -229,8 +267,7 @@ struct IotaBroadcast : public OpRewritePattern { auto iota_dimension = iota.iota_dimension(); auto iota_type = RankedTensorType::get( - {result_ty.getDimSize(iota_dimension.getLimitedValue())}, - result_ty.getElementType()); + {result_ty.getDimSize(iota_dimension)}, result_ty.getElementType()); auto new_iota = rewriter.create(iota.getLoc(), iota_type, rewriter.getI64IntegerAttr(0)); @@ -250,7 +287,7 @@ void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results, } OpFoldResult IotaOp::fold(ArrayRef operands) { - auto dimension = iota_dimension().getLimitedValue(); + auto dimension = iota_dimension(); auto result_ty = getResult().getType().cast(); if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) { Builder builder(getContext()); @@ -294,7 +331,7 @@ struct DynamicIotaBroadcast : public OpRewritePattern { } auto iota_dimension = iota.iota_dimension(); - auto iota_dimension_int = iota_dimension.getLimitedValue(); + auto iota_dimension_int = iota_dimension; auto converted_shape = rewriter.create( iota.getLoc(), @@ -340,6 +377,33 @@ void DynamicIotaOp::getCanonicalizationPatterns( results.insert(context); } +//===----------------------------------------------------------------------===// +// DynamicUpdateSliceOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(DynamicUpdateSliceOp op) { + OperandRange indices = 
op.start_indices(); + if (indices.size() <= 1) return success(); + + // Note: start_indices is constrained to Variadic, so it + // is OK to cast indices to ShapedType here. + auto idx_tensor = indices.take_front().front().getType().cast(); + Type first_elem_ty = idx_tensor.getElementType(); + Type elem_ty; + + for (auto idx : llvm::drop_begin(indices, 1)) { + idx_tensor = idx.getType().cast(); + elem_ty = idx_tensor.getElementType(); + + if (first_elem_ty != elem_ty) { + return op.emitOpError() << "start indices must have same element type " + "(encountered mismatch: " + << first_elem_ty << " vs " << elem_ty << ")"; + } + } + return success(); +} + //===----------------------------------------------------------------------===// // AbsOp //===----------------------------------------------------------------------===// @@ -466,7 +530,7 @@ static LogicalResult Verify(DequantizeOp op) { //===----------------------------------------------------------------------===// static LogicalResult Verify(GetTupleElementOp op) { - auto indexVal = op.index().getZExtValue(); + auto indexVal = op.index(); auto operandType = op.getOperand().getType().cast(); if (indexVal >= operandType.size()) { return op.emitOpError( @@ -485,7 +549,7 @@ static LogicalResult Verify(GetTupleElementOp op) { OpFoldResult GetTupleElementOp::fold(ArrayRef operands) { if (auto tupleOp = dyn_cast_or_null(getOperand().getDefiningOp())) { - return tupleOp.getOperand(index().getLimitedValue()); + return tupleOp.getOperand(index()); } return {}; @@ -506,6 +570,46 @@ static LogicalResult Verify(TupleOp op) { return success(); } +namespace { + +// Pattern for unpacking and repacking the same tuple. +struct UnpackRepackSameTuple : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TupleOp op, + PatternRewriter& rewriter) const override { + if (op.val().empty()) return failure(); + + Value first_element = op.val().front(); + auto first_element_op = + dyn_cast_or_null(first_element.getDefiningOp()); + if (!first_element_op || first_element_op.indexAttr().getInt() != 0) + return failure(); + + Value tuple_predecessor = first_element_op.getOperand(); + if (tuple_predecessor.getType() != op.getType()) return failure(); + + for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) { + auto element_op = dyn_cast_or_null( + element_and_idx.value().getDefiningOp()); + if (!element_op || + element_op.indexAttr().getInt() != element_and_idx.index() + 1 || + element_op.getOperand() != tuple_predecessor) + return failure(); + } + + rewriter.replaceOp(op, tuple_predecessor); + return success(); + } +}; + +} // namespace + +void TupleOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // AllToAllOp //===----------------------------------------------------------------------===// @@ -515,8 +619,8 @@ static LogicalResult Verify(AllToAllOp op) { // count. 
auto type = op.getOperand().getType().dyn_cast(); if (!type) return success(); - auto split_dim_size = type.getDimSize(op.split_dimension().getSExtValue()); - auto split_count = op.split_count().getSExtValue(); + auto split_dim_size = type.getDimSize(op.split_dimension()); + auto split_count = op.split_count(); if (split_dim_size % split_count != 0) { return op.emitError() << "split dimension has size " << split_dim_size << ", expected to be a multiple of split_count " @@ -708,10 +812,12 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) { auto dimSize = operandType.getDimSize(i); auto resultDimSize = resultType.getDimSize(dimIndex); - if (dimSize != 1 && dimSize != resultDimSize) { + // Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we + // add a manual check for this. + if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) { return op.emitOpError( - llvm::formatv("size of operand dimension {0} ({1}) is not equal to " - "1 or size of result dimension {2} ({3})", + llvm::formatv("size of operand dimension {0} ({1}) is not compatible " + "with size of result dimension {2} ({3})", i, dimSize, dimIndex, resultDimSize)); } } @@ -862,7 +968,7 @@ class ConcatenateOperandRemoval : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ConcatenateOp op, PatternRewriter& rewriter) const override { - auto axis = op.dimension().getLimitedValue(); + auto axis = op.dimension(); llvm::SmallVector new_operands; for (auto operand : op.getOperands()) { auto ty = operand.getType().cast(); @@ -903,13 +1009,38 @@ LogicalResult ConcatenateOp::inferReturnTypes( } } - // If an input is unranked the output shape is unranked. + // Find the first ranked input to determine the output rank. + for (auto type : operands.getTypes()) { + auto shaped_type = type.cast(); + if (shaped_type.hasRank()) { + first_type = shaped_type; + break; + } + } + + // If all inputs are unranked, the result must be unranked. if (!first_type.hasRank()) { inferredReturnTypes.push_back(UnrankedTensorType::get(out_element)); return success(); } auto out_shape = llvm::to_vector<6>(first_type.getShape()); + + // Determine what the non-concatenate dimensions should be. + for (auto type : operands.getTypes()) { + auto shaped_ty = type.cast(); + if (!shaped_ty.hasRank()) { + continue; + } + + for (auto it : llvm::enumerate(shaped_ty.getShape())) { + // If a dimension is not dynamic, the output shape should match. 
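The extended ConcatenateOp::inferReturnTypes logic continues below: dynamic dimensions of the first ranked operand are filled in from any operand that knows them, and the concatenation dimension is then summed up. A standalone sketch of that merge with hypothetical shapes, using -1 as a stand-in for a dynamic size (the behaviour when an operand's concat dimension is dynamic is an assumption here):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kDynamic = -1;  // stand-in for ShapedType::kDynamicSize

int main() {
  // Hypothetical operand shapes for a concatenate along dimension 0.
  std::vector<std::vector<int64_t>> shapes = {
      {5, kDynamic}, {3, 4}, {2, kDynamic}};
  int64_t dim = 0;

  // Start from the first ranked operand, then fill dynamic dimensions from
  // operands that do know the size.
  std::vector<int64_t> out = shapes.front();
  for (const auto& s : shapes)
    for (size_t i = 0; i < out.size(); ++i)
      if (out[i] == kDynamic) out[i] = s[i];

  // The concatenation dimension is the sum of the operand sizes; assume one
  // dynamic operand makes it dynamic.
  out[dim] = 0;
  for (const auto& s : shapes) {
    if (s[dim] == kDynamic) { out[dim] = kDynamic; break; }
    out[dim] += s[dim];
  }

  for (int64_t d : out) std::cout << d << " ";  // prints: 10 4
  std::cout << "\n";
}
```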
+ if (ShapedType::isDynamic(out_shape[it.index()])) { + out_shape[it.index()] = it.value(); + } + } + } + out_shape[dimension] = 0; for (auto operand : operands.getTypes()) { @@ -942,7 +1073,7 @@ void ConcatenateOp::getCanonicalizationPatterns( template static Attribute foldConcatenateHelper(ConcatenateOp* op, ArrayRef operands) { - auto axis = op->dimension().getLimitedValue(); + auto axis = op->dimension(); auto type = op->getType().cast(); SmallVector values; @@ -990,7 +1121,7 @@ OpFoldResult ConcatenateOp::fold(ArrayRef operands) { ShapedType type = getResult().getType().cast(); if (!type.hasStaticShape()) return {}; - auto axis = dimension().getLimitedValue(); + auto axis = dimension(); if (auto attr = foldConcatenate(this, operands)) { return attr; } @@ -1165,6 +1296,131 @@ static LogicalResult Verify(InfeedOp op) { return success(); } +//===----------------------------------------------------------------------===// +// Logical Ops +//===----------------------------------------------------------------------===// + +OpFoldResult AndOp::fold(ArrayRef operands) { + if (lhs() == rhs()) return lhs(); + + auto rType = getType().cast(); + auto lhsVal = operands[0].dyn_cast_or_null(); + auto rhsVal = operands[1].dyn_cast_or_null(); + + if (lhsVal && lhsVal.isSplat()) { + if (lhsVal.getSplatValue() + .cast() + .getValue() + .isAllOnesValue()) { + return rhs(); + } + + if (lhsVal.getSplatValue().cast().getValue().isNullValue()) { + return lhsVal; + } + } + + if (rhsVal && rhsVal.isSplat()) { + if (rhsVal.getSplatValue() + .cast() + .getValue() + .isAllOnesValue()) { + return lhs(); + } + + if (rhsVal.getSplatValue().cast().getValue().isNullValue()) { + return rhsVal; + } + } + + if (!rhsVal || !lhsVal) return {}; + + llvm::SmallVector values; + values.reserve(rhsVal.getNumElements()); + for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) { + values.push_back(std::get<0>(it) & std::get<1>(it)); + } + + return DenseIntElementsAttr::get(rType, values); +} + +OpFoldResult OrOp::fold(ArrayRef operands) { + if (lhs() == rhs()) return lhs(); + + auto rType = getType().cast(); + auto lhsVal = operands[0].dyn_cast_or_null(); + auto rhsVal = operands[1].dyn_cast_or_null(); + + if (lhsVal && lhsVal.isSplat()) { + if (lhsVal.getSplatValue() + .cast() + .getValue() + .isAllOnesValue()) { + return lhsVal; + } + + if (lhsVal.getSplatValue().cast().getValue().isNullValue()) { + return rhs(); + } + } + + if (rhsVal && rhsVal.isSplat()) { + if (rhsVal.getSplatValue() + .cast() + .getValue() + .isAllOnesValue()) { + return rhsVal; + } + + if (rhsVal.getSplatValue().cast().getValue().isNullValue()) { + return lhs(); + } + } + + if (!rhsVal || !lhsVal) return {}; + + llvm::SmallVector values; + values.reserve(rhsVal.getNumElements()); + for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) { + values.push_back(std::get<0>(it) | std::get<1>(it)); + } + + return DenseIntElementsAttr::get(rType, values); +} + +OpFoldResult XorOp::fold(ArrayRef operands) { + auto rType = getType().cast(); + if (lhs() == rhs()) { + Builder builder(getContext()); + return builder.getZeroAttr(rType); + } + + auto lhsVal = operands[0].dyn_cast_or_null(); + auto rhsVal = operands[1].dyn_cast_or_null(); + + if (lhsVal && lhsVal.isSplat()) { + if (lhsVal.getSplatValue().cast().getValue().isNullValue()) { + return rhs(); + } + } + + if (rhsVal && rhsVal.isSplat()) { + if (rhsVal.getSplatValue().cast().getValue().isNullValue()) { + return lhs(); + } + } + + if (!rhsVal || !lhsVal) return {}; + + 
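The And/Or/Xor folders in this hunk first apply the usual algebraic identities when one side is a splat constant, and only fall back to the element-wise loop below when both sides are constants. A standalone check of those identities in plain C++ (splat constants modelled as single integers, not the MLIR attributes):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// and: x & allOnes = x, x & 0 = 0, x & x = x
// or : x | allOnes = allOnes, x | 0 = x, x | x = x
// xor: x ^ 0 = x, x ^ x = 0
int main() {
  const uint8_t kAllOnes = 0xFF, kZero = 0x00, x = 0b10110010;
  std::cout << std::boolalpha
            << ((x & kAllOnes) == x) << " " << ((x & kZero) == kZero) << " "
            << ((x | kAllOnes) == kAllOnes) << " " << ((x | kZero) == x) << " "
            << ((x ^ kZero) == x) << " " << ((x ^ x) == kZero) << "\n";

  // When both operands are constants, the fold is a plain element-wise loop.
  std::vector<uint8_t> lhs = {0x0F, 0xF0}, rhs = {0x3C, 0x3C}, out;
  for (size_t i = 0; i < lhs.size(); ++i) out.push_back(lhs[i] ^ rhs[i]);
  for (unsigned v : out) std::cout << std::hex << v << " ";  // 33 cc
  std::cout << "\n";
}
```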
llvm::SmallVector values; + values.reserve(rhsVal.getNumElements()); + for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) { + values.push_back(std::get<0>(it) ^ std::get<1>(it)); + } + + return DenseIntElementsAttr::get(rType, values); +} + //===----------------------------------------------------------------------===// // MapOp //===----------------------------------------------------------------------===// @@ -1358,6 +1614,29 @@ static LogicalResult Verify(SelectOp op) { return success(); } +OpFoldResult SelectOp::fold(ArrayRef operands) { + if (on_true() == on_false()) { + return on_true(); + } + + auto predicate = operands[0].dyn_cast_or_null(); + if (!predicate) { + return {}; + } + + auto predicateTy = predicate.getType().cast(); + if (!predicateTy.getElementType().isInteger(1)) { + return {}; + } + + if (predicate.isSplat()) { + return predicate.getSplatValue().getBoolValue() ? on_true() + : on_false(); + } + + return {}; +} + // Makes it such that a SelectOp that is a non-root operation in a DRR infers // the return type based on operand type. LogicalResult SelectOp::inferReturnTypes( @@ -1399,6 +1678,20 @@ LogicalResult SelectOp::inferReturnTypes( return success(); } +LogicalResult SelectOp::inferReturnTypeComponents( + mlir::MLIRContext*, llvm::Optional, mlir::ValueRange, + mlir::DictionaryAttr, mlir::RegionRange, + llvm::SmallVectorImpl&) { + // TODO(b/168772852) + return failure(); +} + +LogicalResult SelectOp::reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); +} + //===----------------------------------------------------------------------===// // PadOp //===----------------------------------------------------------------------===// @@ -1546,6 +1839,79 @@ static LogicalResult Verify(CaseOp op) { return success(); } +//===----------------------------------------------------------------------===// +// SqrtOp +//===----------------------------------------------------------------------===// + +OpFoldResult SqrtOp::fold(ArrayRef operands) { + auto val = operands[0].dyn_cast_or_null(); + if (!val) return {}; + + auto type = getElementTypeOrSelf(getType()); + if (!type.isF32() && !type.isF64()) return {}; + + auto shaped_type = getType().cast(); + if (!shaped_type.hasStaticShape()) return {}; + + int bit_width = type.getIntOrFloatBitWidth(); + llvm::SmallVector values; + values.reserve(val.getNumElements()); + for (auto it : val.getFloatValues()) { + double value = bit_width == 32 ? it.convertToFloat() : it.convertToDouble(); + if (value < 0) return {}; + value = std::sqrt(value); + if (bit_width == 32) + values.emplace_back(static_cast(value)); + else + values.emplace_back(value); + } + return DenseFPElementsAttr::get(shaped_type, values); +} + +//===----------------------------------------------------------------------===// +// UnaryOps +//===----------------------------------------------------------------------===// + +template +static Attribute UnaryFolder(Op* op, ArrayRef attrs) { + if (!attrs[0]) return {}; + + DenseElementsAttr val = attrs[0].dyn_cast(); + if (!val) return {}; + + ShapedType type = op->getType().template cast(); + if (!type.hasStaticShape()) { + return {}; + } + + Type etype = type.getElementType(); + + // Evaluate for integer values. 
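The UnaryFolder template here is parameterized over an element type and a conversion functor and, after the element-type checks below, simply maps that functor over the constant's values (NegOp uses std::negate via the UNARY_FOLDER macro). A small standalone analogue of that scheme:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Standalone analogue of the UnaryFolder/UNARY_FOLDER pattern: map a functor
// across a constant's values to produce the folded constant.
template <typename T, typename Convert>
std::vector<T> UnaryFold(const std::vector<T>& values) {
  std::vector<T> result;
  result.reserve(values.size());
  for (const T& v : values) result.push_back(Convert()(v));
  return result;
}

int main() {
  std::vector<int64_t> attr = {1, -2, 3};
  for (int64_t v : UnaryFold<int64_t, std::negate<int64_t>>(attr))
    std::cout << v << " ";  // prints: -1 2 -3
  std::cout << "\n";
}
```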
+ if (!etype.isa()) { + return {}; + } + + SmallVector values; + values.reserve(val.getNumElements()); + for (const auto v : val.getValues()) { + values.push_back(Convert()(v)); + } + + return DenseElementsAttr::get(type, values); +} + +#define UNARY_FOLDER(Op, Func) \ + OpFoldResult Op::fold(ArrayRef attrs) { \ + if (getElementTypeOrSelf(getType()).isa()) \ + return UnaryFolder>(this, attrs); \ + if (getElementTypeOrSelf(getType()).isa()) \ + return UnaryFolder>(this, attrs); \ + return {}; \ + } + +UNARY_FOLDER(NegOp, std::negate); + //===----------------------------------------------------------------------===// // BinaryOps //===----------------------------------------------------------------------===// @@ -1720,11 +2086,11 @@ static Attribute FoldSlice(SliceOp* op, I values) { OpFoldResult SliceOp::fold(ArrayRef operands) { // Check if the SliceOp is a NoOp operation. - auto operand_shape = getOperand().getType().cast().getShape(); + auto operand_type = getOperand().getType().cast(); auto result_type = getResult().getType().cast(); - auto result_shape = result_type.getShape(); - if (result_type.hasStaticShape() && (operand_shape == result_shape)) { + if (operand_type.hasStaticShape() && result_type.hasStaticShape() && + (operand_type.getShape() == result_type.getShape())) { return getOperand(); } @@ -1770,7 +2136,7 @@ struct SimplifyConcatSlice : public OpRewritePattern { return failure(); } - auto dimension = concat.dimension().getSExtValue(); + auto dimension = concat.dimension(); auto start = slice.start_indices().getIntValues(); auto limit = slice.limit_indices().getIntValues(); @@ -1920,7 +2286,7 @@ static LogicalResult Verify(SortOp op) { return op.emitOpError("requires all inputs to have the same dimensions"); int64_t rank = input_shape.size(); - int64_t cmp_dim = op.dimension().getSExtValue(); + int64_t cmp_dim = op.dimension(); if (cmp_dim < -rank || cmp_dim >= rank) return op.emitOpError("dimension attribute value must be in range [-") << rank << ", " << rank << "), but found " << cmp_dim; @@ -2121,9 +2487,28 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, build(builder, result, new_type, lhs, rhs, comparison_direction); } +LogicalResult CompareOp::inferReturnTypeComponents( + mlir::MLIRContext*, llvm::Optional, mlir::ValueRange, + mlir::DictionaryAttr, mlir::RegionRange, + llvm::SmallVectorImpl&) { + // TODO(b/168772852) + return failure(); +} + +LogicalResult CompareOp::reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); +} + +} // namespace mhlo +} // namespace mlir #define GET_OP_CLASSES #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" +namespace mlir { +namespace mhlo { + //===----------------------------------------------------------------------===// // mhlo Dialect Interfaces //===----------------------------------------------------------------------===// @@ -2150,7 +2535,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface { //===----------------------------------------------------------------------===// MhloDialect::MhloDialect(MLIRContext* context) - : Dialect(getDialectNamespace(), context) { + : Dialect(getDialectNamespace(), context, TypeID::get()) { addOperations< #define GET_OP_LIST #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc index 9fffeae1cc5..503b100c7ab 100644 
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc @@ -18,16 +18,10 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/register.h" -// Static initialization for *HLO dialects registration. - -void mlir::mhlo::registerAllDialects() { - static bool init_once = []() { - registerDialect(); - registerDialect(); - registerDialect(); - return true; - }(); - (void)init_once; - - // Dependent dialects +void mlir::mhlo::registerAllMhloDialects(mlir::DialectRegistry ®istry) { + // clang-format off + registry.insert(); + // clang-format on } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc index bbb463cd1a9..cba0d3b4788 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -29,6 +29,8 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" @@ -45,17 +47,48 @@ limitations under the License. #include "mlir/IR/Value.h" namespace mlir { -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc" namespace lmhlo { LmhloDialect::LmhloDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { + : Dialect(getDialectNamespace(), context, TypeID::get()) { addOperations< #define GET_OP_LIST #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" >(); } +//===----------------------------------------------------------------------===// +// ConstOp. +//===----------------------------------------------------------------------===// + +/// An lho.constant on an memref that is locally allocated and with no other +/// users (other than dealloc's) can be erased. +// TODO: This can be generalized to an arbitrary op by making use of memory +// effects (write memory effect). +struct EraseConstOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConstOp op, + PatternRewriter& rewriter) const override { + Value memref = op.output(); + if (!memref.getDefiningOp()) { + return failure(); + } + + // Check that all uses of the memref are either DeallocOps or this op. + for (Operation* user : memref.getUsers()) + if (user != op && !isa(user)) return failure(); + + rewriter.eraseOp(op); + return success(); + } +}; + +void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // StaticMemRefCastOp //===----------------------------------------------------------------------===// @@ -126,9 +159,15 @@ static LogicalResult Verify(ReshapeMemRefCastOp op) { return success(); } +} // namespace lmhlo +} // namespace mlir + #define GET_OP_CLASSES #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" +namespace mlir { +namespace lmhlo { + // TODO(cheshire): Support folding, reuse code from hlo_ops.cc. 
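The EraseConstOp canonicalization added to lhlo_ops.cc above erases a constant whose output buffer is produced by a local alloc and whose only other users are deallocs; the legality test is a single scan over the buffer's users. A standalone sketch of that check, with a hypothetical user model rather than the MLIR use-list API:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for the ops using the buffer.
struct Use { std::string op_name; };

// The constant can be erased iff its buffer is locally allocated and every
// other user is a dealloc. (Model only; not the MLIR data structures.)
bool CanEraseConst(bool buffer_is_local_alloc, const std::vector<Use>& users) {
  if (!buffer_is_local_alloc) return false;
  for (const Use& user : users)
    if (user.op_name != "lmhlo.constant" && user.op_name != "dealloc")
      return false;
  return true;
}

int main() {
  std::vector<Use> only_dealloc = {{"lmhlo.constant"}, {"dealloc"}};
  std::vector<Use> also_read = {{"lmhlo.constant"}, {"lmhlo.add"}, {"dealloc"}};
  std::cout << std::boolalpha << CanEraseConst(true, only_dealloc) << " "
            << CanEraseConst(true, also_read) << "\n";  // true false
}
```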
void FusionOp::build(OpBuilder &builder, OperationState &result, diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt new file mode 100644 index 00000000000..e02add4353a --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt @@ -0,0 +1,160 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +include_directories(BEFORE + ${CMAKE_CURRENT_BINARY_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}) + +set(LLVM_TARGET_DEFINITIONS lower_complex_patterns.td) +mlir_tablegen(generated_lower_complex.inc -gen-rewriters) +add_public_tablegen_target(MLIRMhloLowerComplexIncGen) + +set(LLVM_TARGET_DEFINITIONS legalize_to_standard_patterns.td) +mlir_tablegen(generated_legalize_to_standard.inc -gen-rewriters) +add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen) + +set(LLVM_TARGET_DEFINITIONS chlo_legalize_to_hlo_patterns.td) +mlir_tablegen(generated_chlo_legalize_to_hlo.inc -gen-rewriters) +add_public_tablegen_target(MLIRChloLegalizeToHloIncGen) + + +add_mlir_library(ChloPasses + chlo_legalize_to_hlo.cc + chlo_legalize_to_hlo_pass.cc + + DEPENDS + MLIRhlo_opsIncGen + MLIRChloLegalizeToHloIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + ChloDialect + MLIRIR + MLIRPass +) + +add_mlir_library(MhloPasses + legalize_gather_to_torch_index_select.cc + legalize_tanh_to_approximation.cc + lower_complex.cc + lower_complex_patterns.td + lower_general_dot.cc + materialize_broadcasts.cc + materialize_broadcasts_pass.cc + mhlo_fusion.cc + optimize_mhlo.cc + optimize_mhlo_pass.cc + sink_constants_to_control_flow.cc + test_infer_shaped_type_pass.cc + transform_unranked_hlo.cc + unfuse_batch_norm.cc + unfuse_batch_norm_pass.cc + + DEPENDS + MLIRhlo_opsIncGen + MLIRMhloLowerComplexIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRMhloUtils + MLIRPass + MLIRTransformUtils +) + +add_mlir_library(MhloToLhloConversion + hlo_legalize_to_lhlo.cc + + DEPENDS + MLIRhlo_opsIncGen + MLIRlhlo_opsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MhloDialect + LmhloDialect + MLIRIR + MLIRPass +) + +add_mlir_library(MhloToStandard + legalize_control_flow.cc + legalize_to_standard.cc + mhlo_control_flow_to_scf.cc + + DEPENDS + MLIRhlo_opsIncGen + MLIRlhlo_opsIncGen + MLIRMhloLegalizeToStandardIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass +) + +add_mlir_library(MhloLhloToLinalg + legalize_to_linalg.cc + + DEPENDS + MLIRhlo_opsIncGen + MLIRlhlo_opsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MhloDialect + MLIRIR + MLIRPass +) + +add_mlir_library(LmhloPasses + lhlo_fuse_linalg.cc + lhlo_legalize_to_affine.cc + lhlo_legalize_to_gpu.cc + lhlo_legalize_to_llvm.cc + lhlo_legalize_to_llvm_pass.cc + lhlo_legalize_to_parallel_loops.cc + + DEPENDS + MLIRlhlo_opsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + LmhloDialect + MLIRIR + MLIRPass +) + 
+add_library(AllMhloPasses INTERFACE) +target_link_libraries(AllMhloPasses INTERFACE + ChloPasses + MhloPasses + MhloToLhloConversion + MhloToStandard + MhloLhloToLinalg + LmhloPasses +) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index adbd2e5a628..626b5d3bd59 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" @@ -31,6 +33,39 @@ namespace mlir { namespace chlo { namespace { +struct ConvertConstantLikeOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + ConstantLikeOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto result_ty = op.getType().cast(); + + // Unranked uses are not supported. Consider `transform-unranked-hlo`. + if (!result_ty.hasRank()) return failure(); + + // Lower to MHLO constant if statically shaped. + if (result_ty.hasStaticShape()) { + rewriter.replaceOpWithNewOp( + op, DenseElementsAttr::get(result_ty, op.value())); + return success(); + } + + // Lower to broadcasted constant. + ConstantLikeOp::Adaptor transformed(operands); + auto loc = op.getLoc(); + Type extent_tensor_type = shape::getExtentTensorType(op.getContext()); + Value constant = rewriter.create(loc, op.value()); + Value uncasted_shape = rewriter.create( + loc, extent_tensor_type, transformed.operand()); + Type shape_ty = + RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType()); + Value shape = rewriter.create(loc, shape_ty, uncasted_shape); + rewriter.replaceOpWithNewOp( + op, result_ty, constant, shape, rewriter.getI64TensorAttr({})); + return success(); + } +}; + // Converts binary ops that statically are determined to not broadcast directly // to the corresponding mhlo non-broadcasting op. template @@ -124,8 +159,8 @@ struct ConvertRankedDynamicBroadcastBinaryOp int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); Value result_extents = - hlo::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs, - rewriter); + hlo::ComputeBinaryElementwiseBroadcastingResultExtents( + loc, lhs, rhs, rewriter, /*unsafe_as_extent_tensor=*/true); // Note that we unconditionally emit DynamicBroadcastInDim ops and let // downstream canonicalizations fold them away if possible. 
This is @@ -338,30 +373,37 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp Value lhs_shape = if_builder.create(loc, lhs); Value rhs_shape = if_builder.create(loc, rhs); SmallVector ranked_shape(targeted_rank, 1); - auto extent_tensor_type = + auto unknown_rank_extent_tensor_type = RankedTensorType::get( + {RankedTensorType::kDynamicSize}, builder.getIndexType()); + auto known_rank_extent_tensor_type = RankedTensorType::get({targeted_rank}, builder.getIndexType()); auto reshaped_type = RankedTensorType::get( llvm::SmallVector(targeted_rank, RankedTensorType::kDynamicSize), lhs.getType().template dyn_cast().getElementType()); Value ranked_shape_val = if_builder.create( - loc, extent_tensor_type, - mlir::DenseIntElementsAttr::get(extent_tensor_type, ranked_shape)); - // TODO(tpopp): Return extent tensors when possible to signal that this is a - // guaranteed safe broadcast by construction. + loc, known_rank_extent_tensor_type, + mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type, + ranked_shape)); Value extended_lhs = if_builder.create( - loc, extent_tensor_type, lhs_shape, ranked_shape_val, nullptr); + loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val, + nullptr); + Value extended_lhs_casted = if_builder.create( + loc, known_rank_extent_tensor_type, extended_lhs); Value extended_rhs = if_builder.create( - loc, extent_tensor_type, rhs_shape, ranked_shape_val, nullptr); + loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val, + nullptr); + Value extended_rhs_casted = if_builder.create( + loc, known_rank_extent_tensor_type, extended_rhs); // 1. Reshape operands to the given rank (with the same number of elements) // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops // can be broadcasted and do the actual broadcasting) // 3. Type erase the output back to unranked Value reshaped_lhs = if_builder.create( - loc, reshaped_type, lhs, extended_lhs); + loc, reshaped_type, lhs, extended_lhs_casted); Value reshaped_rhs = if_builder.create( - loc, reshaped_type, rhs, extended_rhs); + loc, reshaped_type, rhs, extended_rhs_casted); Value result = if_builder.create( loc, ArrayRef{reshaped_type}, ArrayRef{reshaped_lhs, reshaped_rhs}, op.getAttrs()); @@ -469,10 +511,13 @@ struct HloCompareAdaptor { } }; +#include "generated_chlo_legalize_to_hlo.inc" } // namespace void PopulateLegalizeChloToHloPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { + populateWithGenerated(context, patterns); + // Instantiate conversion templates for conforming binary elementwise ops // that do not have different dtypes between operands and results and do // not have special attributes that need to be preserved. @@ -502,6 +547,9 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, context, patterns); PopulateForBinaryOp( context, patterns); + + // Other patterns. 
+ patterns->insert(context); } } // namespace chlo diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc index 50cd6df5c99..263b6cdd1c3 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc @@ -29,6 +29,10 @@ namespace { struct TestChloLegalizeToHloPass : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnFunction() override { ConversionTarget conversionTarget(getContext()); OwningRewritePatternList conversionPatterns; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td new file mode 100644 index 00000000000..7b612ff4b02 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the legalization pattern definition file for CHLO to MHLO. + +include "mlir/IR/OpBase.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td" + +//===----------------------------------------------------------------------===// +// Unary op patterns. +//===----------------------------------------------------------------------===// + +// Expand acos to MHLO dialect as follows: +// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1 +// = pi if x == -1 +def : Pat<(HLOClient_AcosOp $input), + (HLO_SelectOp + (HLO_CompareOp $input, + (HLO_ConstantLike<"0"> $input), + HLO_COMPARISON_DIRECTION_NE + ), + (HLO_MulOp + (HLO_ConstantLike<"2.0f"> $input), + (HLO_Atan2Op + (HLO_SqrtOp + (HLO_SubOp + (HLO_ConstantLike<"1"> $input), + (HLO_MulOp $input, $input) + ) + ), + (HLO_AddOp + (HLO_ConstantLike<"1"> $input), + $input + ) + ) + ), + (HLO_ConstantLike<"M_PI"> $input))>; + +// Express tan in MHLO dialect as +// tan(x) = sin(x) / cos(x). 
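The acos pattern above and the tan pattern that follows lean on the identities acos(x) = 2 * atan2(sqrt(1 - x^2), 1 + x) for x != -1 (with acos(-1) = pi) and tan(x) = sin(x) / cos(x). A quick standalone numeric check of both, assuming ordinary <cmath> semantics:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // acos(x) vs 2 * atan2(sqrt(1 - x^2), 1 + x), valid for x in (-1, 1].
  for (double x : {-0.9, -0.25, 0.0, 0.5, 0.99}) {
    double expanded = 2.0 * std::atan2(std::sqrt(1.0 - x * x), 1.0 + x);
    std::printf("acos(% .2f): %.12f vs %.12f\n", x, std::acos(x), expanded);
  }
  // tan(x) vs sin(x) / cos(x).
  for (double x : {-1.2, 0.3, 1.0}) {
    std::printf("tan(% .2f):  %.12f vs %.12f\n", x, std::tan(x),
                std::sin(x) / std::cos(x));
  }
}
```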
+def : Pat<(HLOClient_TanOp $input), + (HLO_DivOp + (HLO_SinOp $input), + (HLO_CosOp $input) + )>; + diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index a8c3ad17ebb..0f1a3d034eb 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -78,7 +78,6 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result, } Value InsertAlloc(Location loc, OpResult result, - BufferAssignmentPlacer* bufferAssignment, ConversionPatternRewriter* rewriter) { auto result_type = result.getType().dyn_cast(); if (!result_type || !result_type.hasStaticShape()) { @@ -88,8 +87,7 @@ Value InsertAlloc(Location loc, OpResult result, auto memref_type = MemRefType::get(result_type.getShape(), result_type.getElementType()); OpBuilder::InsertionGuard guard(*rewriter); - rewriter->restoreInsertionPoint( - bufferAssignment->computeAllocPosition(result)); + rewriter->setInsertionPoint(result.getDefiningOp()); auto alloc = rewriter->create(loc, memref_type); return alloc; } @@ -111,8 +109,8 @@ class HloToLhloOpConverter : public BaseOpConversion { return failure(); } if (resultType.hasStaticShape()) { - buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(), - this->bufferAssignment, &rewriter)); + buffer_args.push_back( + InsertAlloc(op->getLoc(), result.value(), &rewriter)); } else { SmallVector results_shape; auto shape_type_op = dyn_cast(op); @@ -259,8 +257,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { const auto& original_results = op.getResults(); SmallVector buffer_args(operands.begin(), operands.end()); for (auto result : original_results) { - buffer_args.push_back( - InsertAlloc(loc, result, this->bufferAssignment, &rewriter)); + buffer_args.push_back(InsertAlloc(loc, result, &rewriter)); } auto new_op = rewriter.create(loc, llvm::None, buffer_args, op.getAttrs()); @@ -290,11 +287,36 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { } }; -// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. This functionality -// is provided by mlir buffer assignment, so use the pattern from there. -// TODO(DFKI): Move this out of detail. -using HloToLhloReturnOpConverter = detail::BufferAssignmentReturnOpConverter< - mhlo::ReturnOp, lmhlo::TerminatorOp, lmhlo::CopyOp, false>; +// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. +struct HloToLhloReturnOpConverter : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + mhlo::ReturnOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto loc = op.getLoc(); + auto& entry_block = op.getParentRegion()->front(); + auto num_arguments = entry_block.getNumArguments(); + if (operands.size() > num_arguments) { + return op.emitError( + "The number of operands that need Copy operations is more " + "than the number of target function arguments."); + } + + // The index of the first output block argument. + auto dest_arg_idx = num_arguments - operands.size(); + + // Create a lmhlo.copy for each operand of mhlo.return. 
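The loop that follows copies each mhlo.return operand into one of the trailing block arguments appended during bufferization, starting at index num_arguments - num_results. A standalone sketch of that index mapping, with hypothetical value and argument names:

```cpp
#include <cstdio>
#include <string>
#include <vector>

// After bufferization the converted region carries one extra trailing memref
// argument per result, so result i is copied into block argument
// (num_arguments - num_results + i). Names below are illustrative only.
int main() {
  std::vector<std::string> block_args = {"%arg0", "%arg1", "%out0", "%out1"};
  std::vector<std::string> return_operands = {"%v0", "%v1"};

  size_t dest_arg_idx = block_args.size() - return_operands.size();
  for (const std::string& operand : return_operands) {
    std::printf("lmhlo.copy %s -> %s\n", operand.c_str(),
                block_args[dest_arg_idx].c_str());
    ++dest_arg_idx;
  }
  std::printf("lmhlo.terminator\n");
}
```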
+ for (Value operand : operands) { + rewriter.create(loc, operand, + entry_block.getArgument(dest_arg_idx)); + ++dest_arg_idx; + } + rewriter.replaceOpWithNewOp(op); + return success(); + } +}; class HloToLhloTensorLoadOpConverter : public BaseOpConversion { @@ -388,6 +410,10 @@ class HloToLhloTensorStoreOpConverter struct HloLegalizeToLhlo : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: HloLegalizeToLhlo() = default; HloLegalizeToLhlo(const HloLegalizeToLhlo& o) { @@ -428,28 +454,19 @@ struct HloLegalizeToLhlo isMemRefType); }); - auto module = getOperation(); - WalkResult result = module.walk([&](FuncOp func) -> WalkResult { - BufferAssignmentPlacer bufferAssignment(func); - OwningRewritePatternList patterns; - populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment, - &converter, &patterns); - if (results_escape_function) { - populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp, - /*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment, - &converter, &patterns); - } else { - populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp, - /*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment, - &converter, &patterns); - } - return applyPartialConversion(func, target, patterns); - }); - if (result.wasInterrupted()) { + auto kind = results_escape_function + ? BufferAssignmentTypeConverter::KeepAsFunctionResult + : BufferAssignmentTypeConverter::AppendToArgumentsList; + converter.setResultConversionKind( + kind); + converter.setResultConversionKind(kind); + + populateHLOToLHLOConversionPattern(&context, &converter, &patterns); + populateWithBufferAssignmentOpConversionPatterns< + mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter, + &patterns); + if (failed(applyPartialConversion(getOperation(), target, patterns))) signalPassFailure(); - } } private: @@ -462,8 +479,8 @@ struct HloLegalizeToLhlo } // namespace void populateHLOToLHLOConversionPattern( - MLIRContext* context, BufferAssignmentPlacer* bufferAssignment, - TypeConverter* converter, OwningRewritePatternList* patterns) { + MLIRContext* context, BufferAssignmentTypeConverter* converter, + OwningRewritePatternList* patterns) { // clang-format off patterns->insert< HloToLhloDynamicBroadcastInDimOpConverter, @@ -471,6 +488,7 @@ void populateHLOToLHLOConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, @@ -483,6 +501,7 @@ void populateHLOToLHLOConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, @@ -497,14 +516,17 @@ void populateHLOToLHLOConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloReduceOpConverter, HloToLhloReturnOpConverter, HloToLhloTensorLoadOpConverter, HloToLhloTensorStoreOpConverter - >(context, bufferAssignment, converter); + >(context, converter); // clang-format on } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index f47f2c2fbdc..0a8105eb366 100644 --- 
a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -15,6 +15,8 @@ limitations under the License. // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect. +#include + #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" @@ -598,6 +600,7 @@ class ReshapeOpConverter : public OpConversionPattern { unsigned currSrcDim = 0, currDstDim = 0; SmallVector reassociationMap( dstShape.size()); + bool isExpandingOrCollapsing = true; while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) { int64_t dstSize = dstShape[currDstDim]; int64_t srcSize = srcShape[currSrcDim]; @@ -619,11 +622,48 @@ class ReshapeOpConverter : public OpConversionPattern { } } } else { - return failure(); + isExpandingOrCollapsing = false; + break; } currDstDim++; } - if (currSrcDim != srcShape.size()) return failure(); + if (currSrcDim != srcShape.size() || currDstDim != dstShape.size()) + isExpandingOrCollapsing = false; + + if (!isExpandingOrCollapsing) { + auto getIdentityExprs = [&rewriter](int n) { + SmallVector exprs; + for (int i = 0; i < n; ++i) + exprs.push_back(rewriter.getAffineDimExpr(i)); + return exprs; + }; + Location loc = reshapeOp.getLoc(); + int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1, + std::multiplies()); + auto elemType = operandType.getElementType(); + SmallVector collapsingMap = { + getIdentityExprs(dstShape.size())}; + SmallVector expandingMap = { + getIdentityExprs(srcShape.size())}; + + if (isLHLO) { + auto collapsedType = MemRefType::get({totalElems}, elemType); + Value collapsedOp = rewriter.create( + loc, collapsedType, args[0], collapsingMap); + Value reshapeBuffer = rewriter.create( + loc, resultType, collapsedOp, expandingMap); + rewriter.replaceOpWithNewOp( + reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr, + /*outputPermutation =*/nullptr); + } else { + auto collapsedType = RankedTensorType::get({totalElems}, elemType); + Value collapsedOp = rewriter.create( + loc, collapsedType, args[0], collapsingMap); + rewriter.replaceOpWithNewOp( + reshapeOp, resultType, collapsedOp, expandingMap); + } + return success(); + } if (isLHLO) { Value reshapeBuffer = rewriter.create( @@ -665,7 +705,7 @@ class IotaConverter : public OpConversionPattern { [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs, ValueRange args) { Value castOp = nestedBuilder.create( - nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()], + nestedLoc, ivs[iotaOp.iota_dimension()], nestedBuilder.getIntegerType( resultElementType.getIntOrFloatBitWidth())); if (resultElementType.template isa()) { @@ -783,6 +823,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -801,7 +842,8 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, ReshapeOpConverter, ReverseConverter, ScalarPointwiseToStandardConverter, - SliceConverter + SliceConverter, + TransposeConverter >(context); // clang-format on } @@ -827,6 +869,10 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () struct LhloLegalizeToLinalgPass : public PassWrapper { + 
void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); @@ -843,6 +889,10 @@ struct LhloLegalizeToLinalgPass struct HloLegalizeToLinalgPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); @@ -882,6 +932,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc index cc574e008d5..d2d4bab45ab 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc @@ -117,7 +117,7 @@ class ConvertIotaOp : public OpRewritePattern { PatternRewriter &rewriter) const override { auto output_type = op.getType().cast(); auto output_size = output_type.getNumElements(); - auto dimension = op.iota_dimension().getSExtValue(); + auto dimension = op.iota_dimension(); auto max_dim_size = output_type.getDimSize(dimension); auto element_type = output_type.getElementType(); @@ -178,6 +178,10 @@ class ConvertIotaOp : public OpRewritePattern { namespace { struct LegalizeToStandardPass : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + /// Perform the lowering to Standard dialect. void runOnFunction() override; }; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td index ea67c052c5c..6ee6f124628 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td @@ -36,6 +36,10 @@ def IsSameSizePred : CPred< def IsSameSizeConstraint : Constraint; +// Unary Lowering Patterns. +def : Pat<(HLO_CeilOp HLO_FpTensor:$i), (CeilFOp $i)>; + +// Binary Lowering Patterns. def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r), (AndOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc deleted file mode 100644 index 7a4418466b5..00000000000 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements a pass to remove redundant LHLO copy operations. - -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Operation.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace lmhlo { -namespace { - -// Removes LHLO copy operations that copy from allocated buffers to block -// arguments. All uses of each buffer are replaced with the corresponding block -// argument and the buffer is freed. Note that this pass only works in regions -// with a single block. -struct LhloCopyRemovalPass - : mlir::PassWrapper> { - void runOnOperation() override { - llvm::SmallVector eraseList; - auto operation = getOperation(); - operation->walk([&](mlir::lmhlo::CopyOp copyOp) { - // If this region contains more than one block, then ignore this copy - // operation. - if (copyOp.getParentRegion()->getBlocks().size() > 1) { - return; - } - - mlir::Value fromOperand = copyOp.operand(); - mlir::Value toOperand = copyOp.output(); - - // If the fromOperand value is a block argument or the toOperand - // value is not a block argument, then ignore this copy operation. - if (!fromOperand.getDefiningOp() || toOperand.getDefiningOp()) { - return; - } - - // The copy operation removal is illegal if there is at least a single use - // of toOperand value that lies between the first use of fromOperand value - // and the copy operation. - auto fromOperandUsers = fromOperand.getUsers(); - auto firstUser = *fromOperandUsers.begin(); - for (auto op : fromOperandUsers) { - if (op->isBeforeInBlock(firstUser)) firstUser = op; - } - for (auto op : toOperand.getUsers()) { - if (op->isBeforeInBlock(copyOp) && firstUser->isBeforeInBlock(op)) { - return; - } - } - - // TODO(DFKI): Use live variable analysis to solve aliasing issues among - // block arguments. - - // Remove the associated alloc operation. - auto allocOp = fromOperand.getDefiningOp(); - eraseList.push_back(allocOp); - - // Iterate over all uses of the fromOperand to find the associated - // deallocOp (if any). - for (auto op : fromOperandUsers) { - if (isa(op)) { - eraseList.push_back(op); - break; - } - } - - // Replace all uses of the fromOperand with the toOperand. This rewires - // all references pointing to the original alloc operation to the new - // target operation in order to safely remove the copy op. - fromOperand.replaceAllUsesWith(toOperand); - copyOp.erase(); - }); - for (auto op : eraseList) { - op->erase(); - } - }; -}; - -} // namespace - -std::unique_ptr createLhloCopyRemovalPass() { - return std::make_unique(); -} - -} // namespace lmhlo -} // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc index 1467f015dc9..6dc5b64a105 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc @@ -19,8 +19,10 @@ limitations under the License. 
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/FoldUtils.h" @@ -33,6 +35,10 @@ using linalg::LinalgOp; class LhloFuseLinalgPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: LhloFuseLinalgPass() = default; LhloFuseLinalgPass(const LhloFuseLinalgPass&) {} diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc index 07891327775..2771afc6302 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -139,6 +139,9 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context, struct LhloLegalizeToAffinePass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } void runOnFunction() override { OwningRewritePatternList patterns; auto func = getFunction(); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc index 0d0b8b0ab6e..fbade8f7387 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc @@ -20,8 +20,10 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" @@ -147,9 +149,9 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { // Now copy over the actual body of the reduction, leaving out the // terminator. 
BlockAndValueMapping mapping; - mapping.map(reduce_op.body().front().getArgument(0), accumulator); - mapping.map(reduce_op.body().front().getArgument(1), rhs); - mapping.map(reduce_op.body().front().getArgument(2), accumulator); + mapping.map(reduce_op.body().getArgument(0), accumulator); + mapping.map(reduce_op.body().getArgument(1), rhs); + mapping.map(reduce_op.body().getArgument(2), accumulator); for (auto& nested : reduce_op.body().front().without_terminator()) { auto clone = rewriter.clone(nested, mapping); for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) { @@ -169,6 +171,11 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { struct LhloLegalizeToGpuPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc index af64c448ad9..57ea947c473 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc @@ -45,7 +45,7 @@ struct StaticMemRefCastOpConverter return failure(); // Create descriptor. auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); - Type llvmTargetElementTy = desc.getElementType(); + Type llvmTargetElementTy = desc.getElementPtrType(); // Set allocated ptr. Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); allocated = @@ -96,7 +96,7 @@ struct DynamicMemRefCastOpConverter return failure(); // Create descriptor. auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); - Type llvmTargetElementTy = desc.getElementType(); + Type llvmTargetElementTy = desc.getElementPtrType(); // Set allocated ptr. Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); allocated = @@ -217,8 +217,7 @@ struct ReshapeMemRefCastOpConverter SmallVector sizes; UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter, {target_desc}, sizes); - auto void_ptr_type = - LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect()); + auto void_ptr_type = LLVM::LLVMType::getInt8PtrTy(rewriter.getContext()); Value ranked_desc_mem = rewriter.create( loc, void_ptr_type, sizes.front(), llvm::None); target_desc.setMemRefDescPtr(rewriter, loc, ranked_desc_mem); @@ -282,7 +281,7 @@ struct ReshapeMemRefCastOpConverter auto index_arg = cond_block->addArgument(typeConverter.getIndexType()); auto stride_arg = cond_block->addArgument(typeConverter.getIndexType()); auto pred = rewriter.create( - loc, LLVM::LLVMType::getInt1Ty(typeConverter.getDialect()), + loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()), LLVM::ICmpPredicate::sge, index_arg, zero_index); Block *body_block = diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc index 00252735023..3d49027bb50 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc @@ -15,8 +15,6 @@ limitations under the License. 
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" -#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -31,6 +29,10 @@ namespace { class TestLhloToLLVMPass : public ::mlir::PassWrapper> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + public: void runOnOperation() override { ModuleOp m = getOperation(); @@ -39,8 +41,6 @@ class TestLhloToLLVMPass LLVMTypeConverter converter(&getContext()); populateStdToLLVMConversionPatterns(converter, patterns); PopulateLhloToLLVMConversionPatterns(&converter, &patterns); - populateLoopToStdConversionPatterns(patterns, &getContext()); - populateAffineToStdConversionPatterns(patterns, &getContext()); ConversionTarget target(getContext()); target.addLegalDialect(); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc index 19f47d08c0d..d9a2d993496 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc @@ -691,6 +691,10 @@ class SelectAndScatterOpConverter struct LhloLegalizeToParallelLoopsPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { auto func = getFunction(); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc new file mode 100644 index 00000000000..dba3cab6956 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc @@ -0,0 +1,199 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/Casting.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" + +#define DEBUG_TYPE "mhlo-control-flow-to-scf" + +namespace mlir { +namespace mhlo { + +namespace { + +/// Convert MHLO While to SCF. +void MatchAndRewrite(WhileOp whileOp); + +/// Pass that converts MHLO control flow to SCF. 
+class ControlFlowToScfPass + : public mlir::PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { + getFunction().walk([&](WhileOp whileOp) { MatchAndRewrite(whileOp); }); + } +}; + +// TODO(jpienaar): Look into reformulating as a pattern. +void MatchAndRewrite(WhileOp whileOp) { + // Handle pattern: + // x = start + // step = ... + // limit = ... + // while (x < limit) { ... x += step; } + + // Only handling multi value while loops at the moment. + auto tupleOp = whileOp.getOperand().getDefiningOp(); + if (!tupleOp) return; + auto bodyReturn = whileOp.body() + .front() + .getTerminator() + ->getOperand(0) + .getDefiningOp(); + // Note: due to the shape restrictions on While, if the operand to While is a + // tuple, then so is the return type of the body. But the verifier isn't + // checking that at the moment, so just bail out here if this doesn't hold. + if (!bodyReturn) return; + + Value result = whileOp.cond().front().getTerminator()->getOperand(0); + // TODO(jpienaar): Expand to handle more than simple case with LT compare and + // constant step. + auto cmp = result.getDefiningOp(); + if (!cmp || cmp.comparison_direction() != "LT") return; + + const int kConstant = -1; + auto getValueAndIndex = [&](Value val) -> std::pair { + if (matchPattern(val, m_Constant())) return {val, kConstant}; + // If it is defined by a tuple, then the tuple has to have been fed in and + // the external value is captured. + if (auto gte = val.getDefiningOp()) { + if (!gte.getOperand().isa()) return {nullptr, 0}; + int index = gte.index(); + return {tupleOp.getOperand(index), index}; + } + return {nullptr, 0}; + }; + + using ValueIndex = std::pair; + ValueIndex loopIndVar = getValueAndIndex(cmp.lhs()); + ValueIndex max = getValueAndIndex(cmp.rhs()); + if (!loopIndVar.first || !max.first) return; + auto add = + bodyReturn.getOperand(loopIndVar.second).getDefiningOp(); + if (!add) return; + ValueIndex step = getValueAndIndex(add.rhs()); + if (step.second != kConstant || !step.first) return; + + // Only handle case where tuple isn't propagated as is for now. + // TODO(jpienaar): Remove this when a tuple is also created inside the loop + // to propagate. + for (auto* use : whileOp.body().front().getArgument(0).getUsers()) + if (!isa(use)) return; + + LLVM_DEBUG(llvm::dbgs() << "Found for (" << whileOp.getLoc() << "):\n"; + llvm::dbgs() << " loopIndVar = " << loopIndVar.second << " max = " + << max.second << " step = " << step.second << "\n"; + llvm::dbgs() << " loopIndVar = " << loopIndVar.first << " max = " + << max.first << " step = " << step.first << "\n";); + OpBuilder b(whileOp); + // Inputs to new for loop. + llvm::SmallVector input; + input.reserve(tupleOp.getNumOperands()); + for (auto r : tupleOp.getOperands().take_front(loopIndVar.second)) + input.push_back(r); + for (auto r : tupleOp.getOperands().drop_front(loopIndVar.second + 1)) + input.push_back(r); + + auto tensorIndexType = RankedTensorType::get({}, b.getIndexType()); + auto getAsIndex = [&](Value val) { + auto loc = whileOp.getLoc(); + return b.create( + loc, b.create(loc, tensorIndexType, val), ValueRange()); + }; + + // SCF for uses index type, so converted these. + auto forloopIndVar = getAsIndex(loopIndVar.first); + auto forMax = getAsIndex(max.first); + auto forStep = getAsIndex(step.first); + auto forOp = b.create(whileOp.getLoc(), forloopIndVar, + forMax, forStep, input); + // Transfer the body without the block arguments. 
+ forOp.getLoopBody().front().getOperations().splice( + forOp.getLoopBody().front().getOperations().end(), + whileOp.body().front().getOperations()); + + b.setInsertionPointToStart(&forOp.getLoopBody().front()); + auto loopIndVarElType = + loopIndVar.first.getType().cast().getElementType(); + Value indVar = b.create( + whileOp.getLoc(), RankedTensorType::get({}, loopIndVarElType), + b.create(whileOp.getLoc(), loopIndVarElType, + forOp.getInductionVar())); + // Update all block argument users to the SCF For args. + for (auto* use : + llvm::make_early_inc_range(whileOp.body().getArgument(0).getUsers())) { + // TODO(jpienaar): Expand here too when we allow using the tuple in the + // loop. + auto gte = cast(use); + // If the loop induction var, then refer to the loop induction variable as + // this operand is not updated. + if (gte.index() == loopIndVar.second) { + use->getResult(0).replaceAllUsesWith(indVar); + use->erase(); + continue; + } + int index = gte.index(); + // If after the loop induction variable, then decrement as we don't include + // the loop induction variable in the for iter operands. + if (index > loopIndVar.second) --index; + use->getResult(0).replaceAllUsesWith(forOp.getIterOperands()[index]); + use->erase(); + } + + // Create new yield op without induction var update. + SmallVector newYieldOps; + newYieldOps.reserve(bodyReturn.getNumOperands() - 1); + for (auto r : bodyReturn.getOperands().take_front(loopIndVar.second)) + newYieldOps.push_back(r); + for (auto r : bodyReturn.getOperands().drop_front(loopIndVar.second + 1)) + newYieldOps.push_back(r); + // Delete return & tuple op. + forOp.getLoopBody().front().back().erase(); + forOp.getLoopBody().front().back().erase(); + b.setInsertionPointToEnd(&forOp.getLoopBody().front()); + b.create(whileOp.getLoc(), newYieldOps); + + // Recombine output tuple with max value of induction variable. + llvm::SmallVector loopOut; + loopOut.reserve(forOp.getNumResults() + 1); + for (auto r : forOp.getResults().take_front(loopIndVar.second)) + loopOut.push_back(r); + loopOut.push_back(max.first); + for (auto r : forOp.getResults().drop_front(loopIndVar.second)) + loopOut.push_back(r); + b.setInsertionPoint(whileOp); + auto newRes = b.create(whileOp.getLoc(), loopOut); + whileOp.replaceAllUsesWith(newRes.getOperation()); + whileOp.erase(); +} + +} // anonymous namespace + +std::unique_ptr> createControlFlowToScfPass() { + return std::make_unique(); +} + +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 8db5d849322..4a17a5b5391 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Dialect/Shape/IR/Shape.h" @@ -27,7 +28,6 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" namespace mlir { -namespace mhlo { namespace { // TODO(herhut): Generate these out of op definitions. 
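In the while-to-scf.for conversion above, the induction element is pulled out of the operand tuple and carried by the loop itself, so every tuple index that sits after the induction position shifts down by one when it becomes an iter operand. A small sketch of that remapping (purely illustrative; RemapTupleIndex is a hypothetical helper, independent of the MLIR types involved):

#include <cassert>

// Map a tuple index from the original mhlo.while operand tuple to the
// position of the corresponding scf.for iter operand, given that the
// induction element at `ind_pos` is no longer part of the iter operands.
// Returns -1 for the induction element itself, which is served by the
// loop's induction variable instead.
int RemapTupleIndex(int index, int ind_pos) {
  if (index == ind_pos) return -1;
  return index > ind_pos ? index - 1 : index;
}

int main() {
  // With the induction element at tuple position 1, elements 0, 2, 3 of the
  // original tuple become iter operands 0, 1, 2.
  assert(RemapTupleIndex(0, 1) == 0);
  assert(RemapTupleIndex(1, 1) == -1);
  assert(RemapTupleIndex(2, 1) == 1);
  assert(RemapTupleIndex(3, 1) == 2);
  return 0;
}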
@@ -46,115 +46,80 @@ namespace { sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \ sep fn(ShiftRightLogicalOp) sep fn(SubOp) -// TODO(frgossen): Make it variadic. +// TODO(herhut): Generate these out of op definitions. +#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) fn(TanOp) sep fn(AcosOp) + template inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { target->addDynamicallyLegalOp([](OpTy op) { - return llvm::all_of((op.getOperation())->getOperandTypes(), + return llvm::all_of(op.getOperation()->getOperandTypes(), [&](Type t) { return t.isa(); }); }); } -/// Unary element-wise operations on unranked tensors can be applied to the -/// flattened tensor with the same effect. -/// This pattern rewrites every such operation to +/// Element-wise operations on unranked tensors can be applied to the flattened +/// tensor operands with the same effect. This pattern rewrites every such +/// operation to /// (i) flatten the input tensor, -/// (ii) apply the unary operation, and +/// (ii) apply the operation, and /// (iii) restore the original shape. template -struct UnaryElementwiseOpConversion : public OpRewritePattern { - explicit UnaryElementwiseOpConversion(MLIRContext *context) +struct ElementwiseOpConversion : public OpRewritePattern { + explicit ElementwiseOpConversion(MLIRContext *context) : OpRewritePattern(context) {} LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - // Don't apply conversion to ops with statically shaped operands. - Value operand = op.getOperand(); - auto operandTy = operand.getType().dyn_cast(); - if (operandTy.hasRank()) return failure(); - - // Generate IR to flatten the operand. - auto loc = op.getLoc(); - Value shape = rewriter.create(loc, operand); - Value numElements = rewriter.create(loc, shape); - Value numElementsAsIndex = - rewriter.create(loc, numElements); - Value flatShapeAsDimTensor = - rewriter.create(loc, numElementsAsIndex); - auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, - operandTy.getElementType()); - Value flatOperand = rewriter.create( - loc, flatTensorTy, operand, flatShapeAsDimTensor); - - // Generate IR for the actual operation. - Value flatResult = rewriter.create(loc, flatTensorTy, flatOperand); - - // Generate IR to restore the original shape. - auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, - rewriter.getIndexType()); - Value shapeAsExtentTensor = - rewriter.create(loc, extentTensorTy, shape); - Value result = rewriter.create( - loc, operandTy, flatResult, shapeAsExtentTensor); - rewriter.replaceOp(op, result); - - return success(); - } -}; - -/// Binary element-wise operation on unranked tensors can be applied to the -/// flattened operand tensors with the same effect. -/// This pattern rewrites every such operation to -/// (i) flatten the operand tensors, -/// (ii) apply the binary operation, and -// (iii) restore the original shape. -template -struct BinaryElementwiseOpConversion : public OpRewritePattern { - explicit BinaryElementwiseOpConversion(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - // Don't apply conversion unless both operands are unranked. - if (op.lhs().getType().template isa() || - op.rhs().getType().template isa()) { + // Don't apply conversion unless all operands are unranked. 
+ if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) { + return operand.getType().isa(); + })) { return failure(); } - // Flatten operands. - Type shapeTy = shape::ShapeType::get(rewriter.getContext()); + // Get operands' shape. auto loc = op.getLoc(); - Value shapeLhs = rewriter.create(loc, op.lhs()); - Value shapeRhs = rewriter.create(loc, op.rhs()); - Value shape = rewriter.create(loc, shapeTy, - ValueRange{shapeLhs, shapeRhs}); - Value numElements = rewriter.create(loc, shape); - Value numElementsAsIndex = - rewriter.create(loc, numElements); - Value flatShape = - rewriter.create(loc, numElementsAsIndex); - TensorType lhsTy = op.lhs().getType().template cast(); - Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, - lhsTy.getElementType()); - Value flatLhs = - rewriter.create(loc, flatLhsTy, op.lhs(), flatShape); - TensorType rhsTy = op.rhs().getType().template cast(); - Type flatRhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, - rhsTy.getElementType()); - Value flatRhs = - rewriter.create(loc, flatRhsTy, op.rhs(), flatShape); + Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); + SmallVector operandShapes; + for (Value operand : op.getOperation()->getOperands()) { + Value shape = + rewriter.create(loc, extentTensorTy, operand); + operandShapes.push_back(shape); + } + Value shape = + operandShapes.size() == 1 + ? operandShapes.front() + : rewriter.create(loc, extentTensorTy, operandShapes); - // Apply actual operation to flattened operands. - Value flatResult = rewriter.create(loc, flatLhs, flatRhs); + // Derive flat shape. + Type indexTy = rewriter.getIndexType(); + Value numElements = + rewriter.create(loc, indexTy, shape); + Value flatShape = rewriter.create(loc, numElements); + + // Flatten operands. + SmallVector flatOperands; + for (Value operand : op.getOperation()->getOperands()) { + Type operandElementTy = + operand.getType().template cast().getElementType(); + Type flatTy = + RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy); + Value flat = rewriter.create(loc, flatTy, operand, + flatShape); + flatOperands.push_back(flat); + } + + // Apply operation to flattened operands. + Type resultElementTy = + op.getType().template cast().getElementType(); + Type flatResultTy = + RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy); + Value flatResult = + rewriter.create(loc, flatResultTy, flatOperands, op.getAttrs()); // Restore original shape. - auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, - rewriter.getIndexType()); - Value shapeAsExtentTensor = - rewriter.create(loc, extentTensorTy, shape); - Value result = rewriter.create( - loc, op.getType(), flatResult, shapeAsExtentTensor); - rewriter.replaceOp(op, result); + rewriter.replaceOpWithNewOp(op, op.getType(), + flatResult, shape); return success(); } @@ -162,24 +127,33 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern { struct TransformUnrankedHloPass : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnFunction() override { // Setup conversion target. 
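The rewrite above relies on element-wise ops ignoring shape entirely: flattening every operand to one dynamically sized 1-D tensor of num_elements(shape) elements, applying the op, and reshaping back to the original extents yields the same result. A minimal sketch of that reasoning in plain C++ (illustrative only; ElementwiseViaFlatten is a hypothetical stand-in, no MLIR types involved):

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Element-wise application over flattened data. The original shape only
// matters for computing the total element count and is restored untouched;
// lhs and rhs are assumed to hold at least that many elements.
std::vector<float> ElementwiseViaFlatten(
    const std::vector<float> &lhs, const std::vector<float> &rhs,
    const std::vector<std::size_t> &shape) {
  std::size_t num_elements = std::accumulate(
      shape.begin(), shape.end(), std::size_t{1}, std::multiplies<>());
  std::vector<float> flat_result(num_elements);
  for (std::size_t i = 0; i < num_elements; ++i)
    flat_result[i] = lhs[i] + rhs[i];  // stand-in for any element-wise op
  // "Restoring" the shape is a metadata change only; the data is unchanged.
  return flat_result;
}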
MLIRContext &ctx = getContext(); ConversionTarget target(ctx); - target.addLegalDialect(); target.addLegalOp(); -#define ADD_LEGAL(op) AddLegalOpOnRankedTensor(&target) - MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL, ;); - MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL, ;); -#undef ADD_LEGAL +#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor(&target) +#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor(&target) + MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;); + MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;); + MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;); +#undef ADD_LEGAL_MHLO +#undef ADD_LEGAL_CHLO + AddLegalOpOnRankedTensor(&target); + AddLegalOpOnRankedTensor(&target); // Populate rewrite patterns. OwningRewritePatternList patterns; PopulateTransformUnrankedHloPatterns(&ctx, &patterns); // Apply transformation. - if (failed(applyFullConversion(getFunction(), target, patterns))) + if (failed(applyPartialConversion(getFunction(), target, patterns))) return signalPassFailure(); } }; @@ -188,24 +162,26 @@ struct TransformUnrankedHloPass void PopulateTransformUnrankedHloPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { - // TODO(frgossen): Populate all unary and binary operations. - // clang-format off -#define MAP_UNARY(op) UnaryElementwiseOpConversion -#define MAP_BINARY(op) BinaryElementwiseOpConversion +#define MAP_UNARY(op) ElementwiseOpConversion +#define MAP_BINARY(op) ElementwiseOpConversion +#define MAP_CHLO_UNARY(op) ElementwiseOpConversion #define COMMA , + // clang-format off patterns->insert< MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA), - MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA) - >(context); + MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA), + MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA), + ElementwiseOpConversion, + ElementwiseOpConversion>(context); + // clang-format on #undef MAP_UNARY #undef MAP_BINARY +#undef MAP_CHLO_UNARY #undef COMMA - // clang-format on } -std::unique_ptr<::mlir::Pass> createTransformUnrankedHloPass() { +std::unique_ptr createTransformUnrankedHloPass() { return std::make_unique(); } -} // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc index 1458e5f3d63..9d072488389 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc @@ -122,7 +122,7 @@ class UnfuseBatchNormInferencePattern if (!fp_type) { return failure(); } - int64_t feature_dim = bn_op.feature_index().getSExtValue(); + int64_t feature_dim = bn_op.feature_index(); // Add epsilon to the variance and sqrt to get stddev: // stddev = sqrt(variance + epsilon) diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/utils/CMakeLists.txt new file mode 100644 index 00000000000..17e86f1caa8 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/utils/CMakeLists.txt @@ -0,0 +1,25 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +add_mlir_library(MLIRMhloUtils + broadcast_utils.cc + convert_op_folder.cc + cycle_detector.cc + hlo_utils.cc + + LINK_LIBS PUBLIC + MLIRSupport + ) diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc b/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc index a3ce4d44436..71b1a4e164f 100644 --- a/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc +++ b/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/StandardTypes.h" @@ -46,9 +47,9 @@ bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs, broadcast_dims.getIntValues().begin()); } -Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, - Value rhs, - OpBuilder& builder) { +Value ComputeBinaryElementwiseBroadcastingResultExtents( + Location loc, Value lhs, Value rhs, OpBuilder& builder, + bool unsafe_as_extent_tensor) { auto lhs_type = lhs.getType().dyn_cast(); auto rhs_type = rhs.getType().dyn_cast(); if (!lhs_type || !rhs_type) { @@ -57,15 +58,22 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, return nullptr; } - int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); Value lhs_shape_v = builder.createOrFold(loc, lhs); Value rhs_shape_v = builder.createOrFold(loc, rhs); - Value result_shape_v = builder.createOrFold( - loc, shape::ShapeType::get(builder.getContext()), lhs_shape_v, - rhs_shape_v, nullptr /* error */); - return builder.createOrFold( - loc, RankedTensorType::get({result_rank}, builder.getIndexType()), - result_shape_v); + + if (unsafe_as_extent_tensor) { + int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); + Value result_shape_v = builder.createOrFold( + loc, shape::getExtentTensorType(builder.getContext()), lhs_shape_v, + rhs_shape_v, nullptr /* error */); + return builder.createOrFold( + loc, RankedTensorType::get({result_rank}, builder.getIndexType()), + result_shape_v); + } + + return builder.createOrFold( + loc, builder.getType(), lhs_shape_v, rhs_shape_v, + nullptr /* error */); } } // namespace hlo diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc b/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc index df2442cc4b6..0bbd91e0680 100644 --- a/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc +++ b/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc @@ -60,10 +60,76 @@ DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) { if (auto float_ty = ty.dyn_cast()) { APFloat value(float_ty.getFloatSemantics(), raw_value); return DenseElementsAttr::get(scalar_ty, value); + } else if (auto int_ty = ty.dyn_cast()) { + APInt value(int_ty.getWidth(), static_cast(raw_value), true); + return DenseElementsAttr::get(scalar_ty, value); + } else if (auto complex_ty = ty.dyn_cast()) { + Type complex_element_ty = complex_ty.getElementType(); + if (complex_element_ty.isF32()) { + return DenseElementsAttr::get( + 
scalar_ty, static_cast>(raw_value)); + } else if (complex_element_ty.isF64()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } } - auto int_ty = ty.cast(); - APInt value(int_ty.getWidth(), static_cast(raw_value), true); - return DenseElementsAttr::get(scalar_ty, value); + llvm_unreachable("unsupported type"); +} + +static APFloat GetScalarLimitOfFloatType(FloatType float_ty, + ScalarLimit limit) { + auto &semantics = float_ty.getFloatSemantics(); + switch (limit) { + case kLowest: + return APFloat::getLargest(semantics, /*negative=*/true); + case kInfinityLowest: + return APFloat::getInf(semantics, /*negative=*/true); + case kMax: + return APFloat::getLargest(semantics, /*negative=*/false); + case kInfinityMax: + return APFloat::getInf(semantics, /*negative=*/false); + } + llvm_unreachable("invalid limit"); +} + +// Returns a scalar value for the given integer type. +// +// The argument 'scalar' describes which scalar value to return. `integer_value` +// is used to specify the integer value for kInteger. For any other scalar, +// integer_value is ignored. +static APInt GetScalarLimitOfIntegerType(IntegerType integer_ty, + ScalarLimit limit) { + unsigned width = integer_ty.getWidth(); + switch (limit) { + case kLowest: + case kInfinityLowest: + if (integer_ty.isUnsigned()) { + return APInt::getMinValue(width); + } else { + return APInt::getSignedMinValue(width); + } + + case kMax: + case kInfinityMax: + if (integer_ty.isUnsigned()) { + return APInt::getMaxValue(width); + } else { + return APInt::getSignedMaxValue(width); + } + } + llvm_unreachable("invalid limit"); +} + +DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) { + RankedTensorType scalar_ty = RankedTensorType::get({}, ty); + if (auto float_ty = ty.dyn_cast()) { + return DenseElementsAttr::get(scalar_ty, + GetScalarLimitOfFloatType(float_ty, limit)); + } else if (auto integer_ty = ty.dyn_cast()) { + return DenseElementsAttr::get( + scalar_ty, GetScalarLimitOfIntegerType(integer_ty, limit)); + } + llvm_unreachable("unsupported type"); } } // namespace hlo diff --git a/tensorflow/compiler/mlir/hlo/tests/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/tests/CMakeLists.txt new file mode 100644 index 00000000000..36a7eec5a1f --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/CMakeLists.txt @@ -0,0 +1,36 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
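For reference, the scalar limits introduced in GetScalarLimitOfType above reduce to the familiar numeric bounds. A hedged example of the values one would expect for a couple of common element types, written against the standard library rather than the APInt/APFloat machinery the patch actually uses:

#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

int main() {
  // kMax / kLowest on a 32-bit signed integer correspond to the usual
  // two's-complement bounds; integers have no infinity, so kInfinityMax and
  // kInfinityLowest collapse to the same values.
  assert(std::numeric_limits<std::int32_t>::max() == 2147483647);
  assert(std::numeric_limits<std::int32_t>::lowest() == -2147483648LL);

  // For f32, kMax / kLowest are the largest finite magnitudes, while
  // kInfinityMax / kInfinityLowest are the signed infinities.
  float max = std::numeric_limits<float>::max();
  float lowest = std::numeric_limits<float>::lowest();
  assert(lowest == -max);
  assert(std::isinf(std::numeric_limits<float>::infinity()));
  return 0;
}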
+# +configure_lit_site_cfg( + ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in + ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py + MAIN_CONFIG + ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py +) + +set(MLIR_HLO_TEST_DEPENDS + FileCheck count not + mlir-hlo-opt +) + +add_lit_testsuite(check-mlir-hlo-lit "Running the mlir-hlo regression tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${MLIR_HLO_TEST_DEPENDS} + ) +set_target_properties(check-mlir-hlo-lit PROPERTIES FOLDER "Tests") + +add_lit_testsuites(MLIR_HLO_OPT ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${MLIR_HLO_TEST_DEPENDS}) + +add_dependencies(check-mlir-hlo check-mlir-hlo-lit) diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir index f0fe52266f0..5da43d5f113 100644 --- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir @@ -191,6 +191,20 @@ func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> { return %2 : tensor<2x2xi32> } +// CHECK-LABEL: constant_like_constant +func @constant_like_constant(%arg0: tensor<3x4xi32>) -> tensor<3x4xf32> { + // CHECK: mhlo.constant dense<3.200000e+00> + %0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<3x4xi32>) -> tensor<3x4xf32> + return %0 : tensor<3x4xf32> +} + +// CHECK-LABEL: constant_like_constant_dynamic +func @constant_like_constant_dynamic(%arg0: tensor<*xi32>) -> tensor<*xf32> { + // CHECK: chlo.constant_like + %0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<*xi32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + // CHECK-LABEL: dynamic_slice_variable_start func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { // CHECK: "mhlo.dynamic-slice" @@ -287,6 +301,13 @@ func @slice_2D_fold_vertical() -> tensor<4x1xi64> { return %1 : tensor<4x1xi64> } +// CHECK-LABEL: slice_unknown_shape +func @slice_unknown_shape(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + // CHECK-LABEL: slice_concat_fold_first func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> @@ -561,3 +582,298 @@ func @dce_while_without_side_effect(%arg0: tensor) -> tensor { return %arg0 : tensor } + +// CHECK-LABEL: unpack_repack_same_tuple +// CHECK-SAME: ([[ARG0:%.*]]: tuple, !mhlo.token, tensor>) +func @unpack_repack_same_tuple(%arg0: tuple, !mhlo.token, tensor>) -> tuple, !mhlo.token, tensor> { + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, !mhlo.token, tensor>) -> tensor + %1 = "mhlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple, !mhlo.token, tensor>) -> !mhlo.token + %2 = "mhlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, !mhlo.token, tensor>) -> tensor + %3 = "mhlo.tuple"(%0, %1, %2) : (tensor, !mhlo.token, tensor) -> tuple, !mhlo.token, tensor> + + // CHECK: return [[ARG0]] + return %3 : tuple, !mhlo.token, tensor> +} + +// CHECK-LABEL: unpack_repack_same_tuple_single_element +// CHECK-SAME: ([[ARG0:%.*]]: tuple>) +func 
@unpack_repack_same_tuple_single_element(%arg0: tuple>) -> tuple> { + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple>) -> tensor + %3 = "mhlo.tuple"(%0) : (tensor) -> tuple> + + // CHECK: return [[ARG0]] + return %3 : tuple> +} + +// CHECK-LABEL: func @erase_dead_lhlo_constant +func @erase_dead_lhlo_constant() { + %M = alloc() : memref<256x1024xf32> + // CHECK-NEXT: return + "lmhlo.constant"(%M) {value = dense<0.0> : tensor} : (memref<256x1024xf32>) -> () + dealloc %M : memref<256x1024xf32> + return +} + +// A negative test for dead lhlo constant op erasure. +// CHECK-LABEL: func @erase_dead_lhlo_constant_negative +func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> { + // CHECK-NEXT: lmhlo.constant + "lmhlo.constant"(%M) {value = dense<0.0> : tensor} : (memref<4xf32>) -> () + // CHECK-NEXT: alloc + // CHECK-NEXT: lmhlo.constant + %N = alloc() : memref<256x1024xf32> + "lmhlo.constant"(%N) {value = dense<0.0> : tensor} : (memref<256x1024xf32>) -> () + return %N : memref<256x1024xf32> +} + +// CHECK-LABEL: func @fold_get_dimension_size +func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor { + %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i32} : (tensor<1x128x512xf32>) -> tensor + return %size : tensor + // CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor + // CHECK-NEXT: return %[[C]] +} + +// CHECK-LABEL: func @fold_select_same +func @fold_select_same(%arg0 : tensor, %arg1 : tensor) -> tensor { + %1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor, tensor, tensor) -> tensor + // CHECK: return %arg0 + return %1 : tensor +} + +// CHECK-LABEL: func @fold_select_first +func @fold_select_first(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor, tensor, tensor) -> tensor + // CHECK: return %arg0 + return %1 : tensor +} + +// CHECK-LABEL: func @fold_select_second +func @fold_select_second(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor, tensor, tensor) -> tensor + // CHECK: return %arg1 + return %1 : tensor +} + +// CHECK-LABEL: func @fold_select_vector +func @fold_select_vector(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>) -> tensor<4xf32> { + %0 = mhlo.constant dense<1> : tensor<4xi1> + %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: return %arg0 + return %1 : tensor<4xf32> +} + +// CHECK-LABEL: gather_to_slice +func @gather_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<3x6x5xf32> { + %0 = constant dense<[1, 2]> : tensor<2xi32> + %1 = "mhlo.gather"(%arg0, %0) { + dimension_numbers = {collapsed_slice_dims = dense<> : tensor<0xi64>, + index_vector_dim = 0 : i64, + offset_dims = dense<[0, 1, 2]> : tensor<3xi64>, + start_index_map = dense<[0, 2]> : tensor<2xi64>}, + indices_are_sorted = false, + slice_sizes = dense<[3, 6, 5]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6x5xf32> + return %1 : tensor<3x6x5xf32> + // CHECK: %[[RET:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[4, 6, 7]> : tensor<3xi64>, start_indices = dense<[1, 0, 2]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<3x6x5xf32> + // CHECK: return %[[RET]] : tensor<3x6x5xf32> +} + +// CHECK-LABEL: gather_scalar_index_to_slice +func @gather_scalar_index_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<5x6x4xf32> { + %0 = constant dense<1> : tensor + %1 = "mhlo.gather"(%arg0, 
%0) { + dimension_numbers = {collapsed_slice_dims = dense<> : tensor<0xi64>, + index_vector_dim = 0 : i64, + offset_dims = dense<[0, 1, 2]> : tensor<3xi64>, + start_index_map = dense<[2]> : tensor<1xi64>}, + indices_are_sorted = false, + slice_sizes = dense<[5, 6, 4]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor) -> tensor<5x6x4xf32> + return %1 : tensor<5x6x4xf32> + // CHECK: %[[RET:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[5, 6, 5]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<5x6x4xf32> + // CHECK: return %[[RET]] : tensor<5x6x4xf32> +} + +// CHECK-LABEL: func @fold_and_same +func @fold_and_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = "mhlo.and"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_and_ones +func @fold_and_ones(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<-1> : tensor<4xi32> + %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_and_zeros +func @fold_and_zeros(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %0 + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_and_constant +func @fold_and_constant(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<7> : tensor<4xi32> + // CHECK: mhlo.and + %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_and_constants +func @fold_and_constants() -> tensor<4xi32> { + %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32> + %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32> + %2 = "mhlo.and"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: %0 = mhlo.constant dense<[0, 1, 6, 2]> : tensor<4xi32> + // CHECK: return %0 + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_same +func @fold_or_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = "mhlo.or"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_ones +func @fold_or_ones(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<-1> : tensor<4xi32> + %1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %0 + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_zeros +func @fold_or_zeros(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_constant +func @fold_or_constant(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<7> : tensor<4xi32> + // CHECK: mhlo.or + %1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_zeros_right +func @fold_or_zeros_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.or"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func 
@fold_or_zeros_constants +func @fold_or_zeros_constants() -> tensor<4xi32> { + %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32> + %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32> + %2 = "mhlo.or"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: %0 = mhlo.constant dense<[7, 3, 7, 3]> : tensor<4xi32> + // CHECK: return %0 + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_same +func @fold_xor_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = "mhlo.xor"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: %0 = mhlo.constant dense<0> : tensor<4xi32> + // CHECK: return %0 + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_ones_left +func @fold_xor_ones_left(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<-1> : tensor<4xi32> + // CHECK: mhlo.xor + %1 = "mhlo.xor"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_ones_right +func @fold_xor_ones_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<-1> : tensor<4xi32> + // CHECK: mhlo.xor + %1 = "mhlo.xor"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_zeros_left +func @fold_xor_zeros_left(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.xor"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_zeros_right +func @fold_xor_zeros_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.xor"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_zeros_constants +func @fold_xor_zeros_constants() -> tensor<4xi32> { + %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32> + %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32> + %2 = "mhlo.xor"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: %0 = mhlo.constant dense<[7, 2, 1, 1]> : tensor<4xi32> + // CHECK: return %0 + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_negate_int +func @fold_negate_int() -> tensor<4xi32> { + %0 = mhlo.constant dense<[0, 1, 6, -3]> : tensor<4xi32> + // CHECK: mhlo.constant dense<[0, -1, -6, 3]> + %1 = "mhlo.negate"(%0) : (tensor<4xi32>) -> tensor<4xi32> + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_negate_float +func @fold_negate_float() -> tensor<4xf32> { + %0 = mhlo.constant dense<[0., 1., 6., -3.]> : tensor<4xf32> + // CHECK: mhlo.constant dense<[-0.000000e+00, -1.000000e+00, -6.000000e+00, 3.000000e+00]> + %1 = "mhlo.negate"(%0) : (tensor<4xf32>) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + +// CHECK-LABEL: func @fold_sqrt_f32_constants +func @fold_sqrt_f32_constants() -> tensor<4xf32> { + %0 = mhlo.constant dense<1.0> : tensor<4xf32> + %1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<1.000000e+00> : tensor<4xf32> + // CHECK-NOT: mhlo.sqrt + return %1 : tensor<4xf32> +} + +// CHECK-LABEL: func @fold_sqrt_f64_constants +func @fold_sqrt_f64_constants() -> tensor<4xf64> { + %0 = mhlo.constant dense<[1.0, 4.0, 9.0, 16.0]> : tensor<4xf64> + %1 = "mhlo.sqrt"(%0) : (tensor<4xf64>) -> tensor<4xf64> + // CHECK: mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf64> + // 
CHECK-NOT: mhlo.sqrt + return %1 : tensor<4xf64> +} + +// CHECK-LABEL: func @not_fold_sqrt_neg_constants +func @not_fold_sqrt_neg_constants() -> tensor<4xf32> { + %0 = mhlo.constant dense<-1.0> : tensor<4xf32> + %1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<-1.000000e+00> : tensor<4xf32> + // CHECK: mhlo.sqrt + return %1 : tensor<4xf32> +} diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir index 99aab532688..0738459f8b6 100644 --- a/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir @@ -1,19 +1,18 @@ -// RUN: mlir-hlo-opt -mhlo-test-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s +// RUN: mlir-hlo-opt --mhlo-test-infer-shaped-type-methods --allow-unregistered-dialect --split-input-file %s | FileCheck %s // CHECK-LABEL: @broadcast_add // Note that all broadcast_ops are expanded from the same template, so // only test reification on an examplar op. // CHECK-SAME: %[[ARG0:.+]]: tensor, // CHECK-SAME: %[[ARG1:.+]]: tensor -func @broadcast_add(%arg0: tensor, %arg1: tensor) -> tensor<1xindex> { +func @broadcast_add(%arg0: tensor, %arg1: tensor) -> !shape.shape { // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] - // CHECK-DAG: %[[BCAST_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] - // CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]] - // CHECK: return %[[EXTENTS]] + // CHECK-DAG: %[[BCAST_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] : tensor, tensor -> !shape.shape + // CHECK: return %[[BCAST_S]] : !shape.shape %0 = chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor - %1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor) -> tensor<1xindex> - return %1 : tensor<1xindex> + %1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor) -> !shape.shape + return %1 : !shape.shape } // ----- diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir index c08ead5081e..af19a9b5c1c 100644 --- a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -19,7 +19,7 @@ func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor to tensor<2xindex> // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]] @@ -40,7 +40,7 @@ func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> t // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] // CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] - // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] + // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor to tensor<2xindex> // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : 
(tensor, tensor<2xindex>) -> tensor // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor // CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor, tensor) -> tensor> @@ -61,7 +61,7 @@ func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> t // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] // CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] - // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] + // CHECK: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor to tensor<2xindex> // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor // CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor @@ -253,7 +253,7 @@ func @addScalarUnranked(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf3 // to a 1D tensor. // CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32> // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor -> index -// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> +// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> // CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // The assuming region is part of the second stage of lowering // with ranked broadcasting logic. @@ -263,7 +263,7 @@ func @addScalarUnranked(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf3 // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor) { // CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape [] // CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]] -// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[BROADCASTED_SHAPE]] : !shape.shape -> tensor<1xindex> +// CHECK: %[[SHAPE_TENSOR:.*]] = tensor_cast %[[BROADCASTED_SHAPE]] : tensor to tensor<1xindex> // CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor // CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor<1xindex>) -> tensor // CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor @@ -288,7 +288,7 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf3 // to a 1D tensor. 
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor -> index -// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> +// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> // CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // The assuming region is part of the second stage of lowering // with ranked broadcasting logic. @@ -296,7 +296,7 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf3 // CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]] // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor) { -// CHECK: %[[ASTENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_RESHAPED]] +// CHECK: %[[ASTENSOR:.*]] = tensor_cast %[[SHAPE_RESHAPED]] // CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor<1xindex>) -> tensor // CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor // CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor @@ -353,10 +353,12 @@ func @addUnrankedUnranked( // Handle rank 2 specialization // CHECK: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) { // CHECK: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1] -// CHECK: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor<2xindex> -// CHECK: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor<2xindex> -// CHECK: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor -// CHECK: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor +// CHECK: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor +// CHECK: %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor to tensor<2xindex> +// CHECK: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor +// CHECK: %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor to tensor<2xindex> +// CHECK: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor // CHECK: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor, tensor) -> tensor // CHECK: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor to tensor<*xf32> // CHECK: scf.yield %[[RESULT_2]] : tensor<*xf32> @@ -366,10 +368,12 @@ func @addUnrankedUnranked( // Handle rank 3 specialization // CHECK: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) { // CHECK: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] -// CHECK: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor, 
tensor<3xindex> -> tensor<3xindex> -// CHECK: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor<3xindex> -// CHECK: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor -// CHECK: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor +// CHECK: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor +// CHECK: %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor to tensor<3xindex> +// CHECK: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor +// CHECK: %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor to tensor<3xindex> +// CHECK: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor // CHECK: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor, tensor) -> tensor // CHECK: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor to tensor<*xf32> // CHECK: scf.yield %[[RESULT_3]] : tensor<*xf32> @@ -379,10 +383,12 @@ func @addUnrankedUnranked( // Handle rank 4 specialization // CHECK: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) { // CHECK: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1] -// CHECK: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor<4xindex> -// CHECK: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor<4xindex> -// CHECK: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor -// CHECK: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor +// CHECK: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor +// CHECK: %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor to tensor<4xindex> +// CHECK: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor +// CHECK: %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor to tensor<4xindex> +// CHECK: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor // CHECK: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor, tensor) -> tensor // CHECK: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor to tensor<*xf32> // CHECK: scf.yield %[[RESULT_4]] : tensor<*xf32> @@ -392,10 +398,12 @@ func @addUnrankedUnranked( // Handle rank 5 specialization // CHECK: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) { // CHECK: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] -// CHECK: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor<5xindex> -// CHECK: 
%[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor<5xindex> -// CHECK: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor -// CHECK: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor +// CHECK: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor +// CHECK: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor to tensor<5xindex> +// CHECK: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor +// CHECK: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor to tensor<5xindex> +// CHECK: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor // CHECK: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor, tensor) -> tensor // CHECK: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor to tensor<*xf32> // CHECK: scf.yield %[[RESULT_5]] : tensor<*xf32> @@ -405,10 +413,12 @@ func @addUnrankedUnranked( // Handle rank 6 specialization // CHECK: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) { // CHECK: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1] -// CHECK: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor<6xindex> -// CHECK: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor<6xindex> -// CHECK: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor -// CHECK: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor +// CHECK: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor +// CHECK: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor to tensor<6xindex> +// CHECK: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor +// CHECK: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor to tensor<6xindex> +// CHECK: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor // CHECK: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor, tensor) -> tensor // CHECK: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor to tensor<*xf32> // CHECK: scf.yield %[[RESULT_6]] : tensor<*xf32> diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir new file mode 100644 index 00000000000..371e730c30b --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-hlo-opt --mhlo-test-chlo-legalize-to-hlo --split-input-file %s | FileCheck %s + +// Lower statically shaped `constant_like` to 
constant. +// CHECK-LABEL: @constant_like_static_shape +func @constant_like_static_shape(%arg : tensor<1x2xi64>) -> tensor<1x2xf32> { + // CHECK: %[[RESULT:.*]] = mhlo.constant dense<3.200000e+00> : tensor<1x2xf32> + // CHECK: return %[[RESULT]] + %result = "chlo.constant_like"(%arg) { value = 3.2 : f32 } + : (tensor<1x2xi64>) -> tensor<1x2xf32> + return %result : tensor<1x2xf32> +} + +// Lower dynamically shaped `constant_like` to broadcasted constant. +// CHECK-LABEL: constant_like_dynamic_shape +// CHECK-SAME: (%[[ARG:.*]]: tensor) +func @constant_like_dynamic_shape(%arg : tensor) -> tensor { + // CHECK: %[[CONSTANT:.*]] = mhlo.constant dense<3.200000e+00> : tensor + // CHECK: %[[UNCASTED_SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor -> tensor + // CHECK: %[[SHAPE:.*]] = tensor_cast %[[UNCASTED_SHAPE]] : tensor to tensor<2xindex> + // CHECK: %[[BROADCASTED_CONSTANT:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[CONSTANT]], %[[SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK: return %[[BROADCASTED_CONSTANT]] : tensor + %result = "chlo.constant_like"(%arg) { value = 3.2 : f32 } + : (tensor) -> tensor + return %result : tensor +} + diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir index 018711e33cb..960a769c388 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir @@ -170,7 +170,7 @@ func @dyn_broadcast(%operand: memref) { // BOTH-SAME: (%[[OPERAND:.*]]: memref) %tensor_operand = tensor_load %operand : memref %c1 = constant 1 : i64 - %shape = tensor_from_elements(%c1, %c1, %c1) : tensor<3xi64> + %shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64> %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) { broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> } : (tensor, tensor<3xi64>) -> tensor @@ -320,6 +320,18 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// BOTH-LABEL: func @floor +func @floor(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "mhlo.floor"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "lmhlo.floor"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + // BOTH-LABEL: func @neg func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -404,7 +416,7 @@ func @add_dyn(%lhs: tensor, %rhs: tensor) { // BOTH: %[[C1:.*]] = constant 1 : index // BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref // BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 - // BOTH: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64> + // BOTH: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64> // BOTH: %[[C0_:.*]] = constant 0 : index // BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64> // BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index @@ -429,7 +441,7 @@ func @tanh_dyn(%arg0: tensor) { // BOTH: %[[C1:.*]] = constant 1 : index // BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref // BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 - // BOTH: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64> + // BOTH: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64> // BOTH: %[[C0_:.*]] = constant 0 : 
index // BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64> // BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index @@ -510,3 +522,16 @@ func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor) -> tensor<1xf32> { : (tensor<1x8xf32>, tensor) -> tensor<1xf32> return %0 : tensor<1xf32> } + +// ----- + +// BOTH-LABEL: func @transpose +func @transpose(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "mhlo.transpose"(%tensor_operand) {permutation = dense<[1, 0]> : tensor<2xi64>} + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "lmhlo.transpose"(%{{.*}}, %{{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>} + // BOTH-NOT: tensor_store + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir index 46725e0bd09..263ea1b4040 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir @@ -152,6 +152,16 @@ func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @floor +func @floor(%input: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: floorf + %0 = "mhlo.floor"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: func @float_neg func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic @@ -373,6 +383,40 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> { // ----- +// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape_3D_4D +func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> { + %0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> + return %0 : tensor<1x784x1x1xf32> +} +// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]] +// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]] + +// ----- + +// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape1_4D_4D +func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> { + %0 = "mhlo.reshape"(%arg0) : (tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> + return %0 : tensor<1x4x1x512xi32> +} +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]] +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]] + +// ----- + +// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape2_4D_4D +func @reshape2_4D_4D(%arg0: tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32> { + %0 = "mhlo.reshape"(%arg0) : (tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32> + return %0 : tensor<4x1024x1x1xi32> +} +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]] +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]] + +// ----- + // CHECK-LABEL: func @minf func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { %0 = "mhlo.minimum"(%lhs, %rhs) diff --git a/tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir similarity index 68% rename from tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir rename to tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir index 
6cc07e0460c..ae61fc8477e 100644 --- a/tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s +// RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s // Check the validity of expected IR. // CHECK-LABEL: @sqr_transform_result @@ -7,8 +7,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { // Flatten operand shape. %shape = shape.shape_of %a : tensor<*xf32> -> tensor %num_elements = shape.num_elements %shape : tensor -> index - %num_elements_as_index = shape.size_to_index %num_elements : index - %flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex> + %flat_shape = tensor_from_elements %num_elements : tensor<1xindex> %flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor @@ -16,8 +15,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { %flat_b = "mhlo.sqrt"(%flat_a) : (tensor) -> tensor // Restore original shape. - %shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor -> tensor - %b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor) + %b = "mhlo.dynamic_reshape"(%flat_b, %shape) : (tensor, tensor) -> tensor<*xf32> return %b : tensor<*xf32> @@ -29,14 +27,12 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: @sqrt // CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> { - // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> + // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] - // CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] - // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> + // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> // CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor) -> tensor - // CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor - // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> + // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> // CHECK-NEXT: return %[[B]] : tensor<*xf32> %b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32> return %b : tensor<*xf32> @@ -73,16 +69,30 @@ func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> { func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[SHAPE_A:.*]] = shape.shape_of %[[A]] // CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]] - // CHECK: %[[SHAPE:.*]] = "shape.any"(%[[SHAPE_A]], %[[SHAPE_B]]) + // CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_A]], %[[SHAPE_B]] // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] - // CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] - // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> + // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : 
(tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor - // CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> // CHECK: return %[[RESULT]] : tensor<*xf32> %result = mhlo.add %a, %b : tensor<*xf32> return %result : tensor<*xf32> } + +// ----- + +// CHECK-LABEL: @tan +// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) -> tensor<*xf32> +func @tan(%a : tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor + // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] + // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> + // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor + // CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor + // CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> + // CHECK: return %[[B]] : tensor<*xf32> + %result = chlo.tan %a : tensor<*xf32> + return %result : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir index 37a61498fbf..abe4e872b73 100644 --- a/tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir @@ -42,6 +42,15 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32 return %4 : tensor<4xi32> } +// CHECK-LABEL: func @unary_ops_float +func @unary_ops_float(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %0 = ceilf %arg0 : tensor<4xf32> + %0 = "mhlo.ceil"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + + // CHECK-NEXT: return %0 : tensor<4xf32> + return %0 : tensor<4xf32> +} + // CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32> diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir new file mode 100644 index 00000000000..9c887a73a0f --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-hlo-opt --mhlo-control-flow-to-scf %s | FileCheck %s + +func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor, %arg2: tensor, %arg3: tensor<4xf32>, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor) -> (tuple, tensor, tensor>) { + %cst = constant dense<-1> : tensor + %cst_0 = constant dense<1> : tensor + %cst_1 = constant dense<0> : tensor + %cst_2 = constant dense<1000> : tensor + %0 = "mhlo.tuple"(%cst_1, %cst, %cst_2) : (tensor, tensor, tensor) -> tuple, tensor, tensor> + %1 = "mhlo.while"(%0) ( { + ^bb0(%arg9: tuple, tensor, tensor>): // no predecessors + %2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple, tensor, tensor>) -> tensor + %3 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple, tensor, 
tensor>) -> tensor + %4 = "mhlo.compare"(%2, %3) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%4) : (tensor) -> () + }, { + ^bb0(%arg9: tuple, tensor, tensor>): // no predecessors + %2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple, tensor, tensor>) -> tensor + %3 = mhlo.add %2, %cst_0 : tensor + %4 = "mhlo.get_tuple_element"(%arg9) {index = 1 : i32} : (tuple, tensor, tensor>) -> tensor + %5 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple, tensor, tensor>) -> tensor + %6 = "mhlo.tuple"(%3, %4, %5) : (tensor, tensor, tensor) -> tuple, tensor, tensor> + "mhlo.return"(%6) : (tuple, tensor, tensor>) -> () + }) : (tuple, tensor, tensor>) -> tuple, tensor, tensor> + return %1 : tuple, tensor, tensor> +} + +// CHECK-LABEL: func @lt_loop( +// CHECK: %[[VAL_9:.*]] = constant dense<-1> : tensor +// CHECK: %[[VAL_10:.*]] = constant dense<1> : tensor +// CHECK: %[[VAL_11:.*]] = constant dense<0> : tensor +// CHECK: %[[VAL_12:.*]] = constant dense<1000> : tensor +// CHECK: %[[VAL_14:.*]] = index_cast %[[VAL_11]] : tensor to tensor +// CHECK: %[[VAL_15:.*]] = extract_element %[[VAL_14]][] : tensor +// CHECK: %[[VAL_16:.*]] = index_cast %[[VAL_12]] : tensor to tensor +// CHECK: %[[VAL_17:.*]] = extract_element %[[VAL_16]][] : tensor +// CHECK: %[[VAL_18:.*]] = index_cast %[[VAL_10]] : tensor to tensor +// CHECK: %[[VAL_19:.*]] = extract_element %[[VAL_18]][] : tensor +// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_19]] iter_args(%[[VAL_22:.*]] = %[[VAL_9]], %[[VAL_23:.*]] = %[[VAL_12]]) diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir deleted file mode 100644 index 3271595900d..00000000000 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir +++ /dev/null @@ -1,115 +0,0 @@ -// RUN: mlir-hlo-opt -lhlo-copy-removal %s -o - | FileCheck %s - -// CHECK-LABEL: func @remove_simple -func @remove_simple(%arg0: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - "lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - // CHECK-NEXT: "lmhlo.terminator"() : () -> () - "lmhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @remove_without_dealloc -func @remove_without_dealloc(%arg0: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - "lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "lmhlo.terminator"() : () -> () - "lmhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @replace_dependency -func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - // CHECK-NEXT: "lmhlo.terminator"() : () -> () - "lmhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @keep_copies -func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { - // CHECK-NEXT: "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "lmhlo.terminator"() : () -> () - "lmhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @must_not_be_removed -func @must_not_be_removed(%arg0: 
memref<2x2xf32>, - %arg1: memref<2x2xf32>, - %arg2: memref<2x2xf32>) { - // CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32> - %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "lmhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "lmhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - "lmhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @must_be_removed_first -func @must_be_removed_first(%arg0: memref<2x2xf32>, - %arg1: memref<2x2xf32>, - %arg2: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - "lmhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @must_be_removed_second -func @must_be_removed_second(%arg0: memref<2x2xf32>, - %arg1: memref<2x2xf32>, - %arg2: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - "lmhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @reduce -func @reduce(%arg0: memref<1x8xf32>, %arg1: memref, %arg2: memref<1xf32>) { - %0 = alloc() : memref<1xf32> - "lmhlo.reduce"(%arg0, %arg1, %0) ( { - // CHECK: ^bb0(%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, - // CHECK-SAME: %[[ARG2:.*]]: memref) - ^bb0(%arg3: memref, %arg4: memref, %arg5: memref): - %1 = alloc() : memref - // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) - "lmhlo.add"(%arg3, %arg4, %1) - : (memref, memref, memref) -> () - // CHECK-NOT; lmhlo.copy - "lmhlo.copy"(%1, %arg5) : (memref, memref) -> () - "lmhlo.terminator"() : () -> () - }) {dimensions = dense<1> : tensor<1xi64>} - : (memref<1x8xf32>, memref, memref<1xf32>) -> () - "lmhlo.copy"(%0, %arg2) : (memref<1xf32>, memref<1xf32>) -> () - return -} diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir index 768d8da22bd..3162f37f912 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir @@ -496,6 +496,18 @@ func @sin(%input: memref<2x2xf32>, // ----- +// CHECK-LABEL: func @floor +func @floor(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "lmhlo.floor"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: 
^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = floorf %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + // CHECK-LABEL: func @negf func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { "lmhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () @@ -688,6 +700,46 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { // ----- +// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape_3D_4D +func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) { + "lmhlo.reshape"(%arg0, %arg1) + : (memref<1x49x16xf32>, memref<1x784x1x1xf32>) -> () + return +} +// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]] +// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]] +// CHECK: linalg.copy + +// ----- + +// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape1_4D_4D +func @reshape1_4D_4D(%arg0: memref<4x512x1x1xi32>, + %arg1: memref<1x4x1x512xi32>) { + "lmhlo.reshape"(%arg0, %arg1) + : (memref<4x512x1x1xi32>, memref<1x4x1x512xi32>) -> () + return +} +// CHECK: linalg.reshape %{{.*}} [#[[MAP]]] +// CHECK: linalg.reshape %{{.*}} [#[[MAP]]] + +// ----- + +// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape2_4D_4D +func @reshape2_4D_4D(%arg0: memref<4x1x1x1024xi32>, + %arg1: memref<4x1024x1x1xi32>) { + "lmhlo.reshape"(%arg0, %arg1) + : (memref<4x1x1x1024xi32>, memref<4x1024x1x1xi32>) -> () + return +} +// CHECK: linalg.reshape %{{.*}} [#[[MAP]]] +// CHECK: linalg.reshape %{{.*}} [#[[MAP]]] + +// ----- + // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @reverse @@ -722,3 +774,16 @@ func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: m "lmhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> () "lmhlo.terminator"() : () -> () } + +// ----- + +// CHECK-DAG: #[[TRANSPOSE_INPUT_MAP:.*]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK-DAG: #[[TRANSPOSE_OUTPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @transpose +func @transpose(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { + "lmhlo.transpose"(%arg0, %arg1) { + permutation = dense<[1, 0]> : tensor<2xi64> + } : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[TRANSPOSE_INPUT_MAP]], #[[TRANSPOSE_OUTPUT_MAP]]] diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir index 5bb1d475b24..45c383bd1d6 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s --test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s +// RUN: mlir-hlo-opt %s -lower-affine -convert-scf-to-std -test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s // CHECK-LABEL: func @static_memref_cast func @static_memref_cast(%buf : memref<10x1x5xf32>) { diff --git a/tensorflow/compiler/mlir/hlo/tests/lit.cfg.py b/tensorflow/compiler/mlir/hlo/tests/lit.cfg.py new file mode 100644 index 00000000000..f81d47a76cd --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/lit.cfg.py @@ 
-0,0 +1,82 @@ +"""Lit configuration to drive test in this repo.""" +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- Python -*- +# pylint: disable=undefined-variable + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import lit.formats +from lit.llvm import llvm_config +from lit.llvm.subst import ToolSubst +import lit.util + +# Configuration file for the 'lit' test runner. + +# name: The name of this test suite. +config.name = 'MLIR_HLO_OPT' + +config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) + +# suffixes: A list of file extensions to treat as test files. +config.suffixes = ['.mlir', '.mlir.py'] + +# test_source_root: The root path where tests are located. +config.test_source_root = os.path.dirname(__file__) + +# test_exec_root: The root path where tests should be run. +config.test_exec_root = os.path.join(config.mlir_hlo_obj_root, 'test') + +config.substitutions.append(('%PATH%', config.environment['PATH'])) +config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) + +llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) + +llvm_config.use_default_substitutions() + +# excludes: A list of directories to exclude from the testsuite. The 'Inputs' +# subdirectories contain auxiliary inputs for various tests in their parent +# directories. +config.excludes = [ + 'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt' +] + +# test_source_root: The root path where tests are located. +config.test_source_root = os.path.dirname(__file__) + +# test_exec_root: The root path where tests should be run. +config.test_exec_root = os.path.join(config.mlir_hlo_obj_root, 'test') +config.mlir_hlo_tools_dir = os.path.join(config.mlir_hlo_obj_root, 'tools') + +# Tweak the PATH to include the tools dir. +llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) + +tool_dirs = [ + os.path.join(config.mlir_hlo_tools_dir, 'mlir-hlo-opt'), + config.llvm_tools_dir, +] +tools = [ + 'mlir-hlo-opt', + 'mlir-cpu-runner', + ToolSubst( + '%mlir_runner_utils_dir', + config.mlir_runner_utils_dir, + unresolved='ignore'), +] + +llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/hlo/tests/lit.site.cfg.py.in b/tensorflow/compiler/mlir/hlo/tests/lit.site.cfg.py.in new file mode 100644 index 00000000000..1555d314df0 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/lit.site.cfg.py.in @@ -0,0 +1,63 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +@LIT_SITE_CFG_IN_HEADER@ + +import sys + +config.host_triple = "@LLVM_HOST_TRIPLE@" +config.target_triple = "@TARGET_TRIPLE@" +config.llvm_src_root = "@LLVM_SOURCE_DIR@" +config.llvm_obj_root = "@LLVM_BINARY_DIR@" +config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" +config.llvm_lib_dir = "@LLVM_LIBRARY_DIR@" +config.llvm_shlib_dir = "@SHLIBDIR@" +config.llvm_shlib_ext = "@SHLIBEXT@" +config.llvm_exe_ext = "@EXEEXT@" +config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" +config.python_executable = "@PYTHON_EXECUTABLE@" +config.gold_executable = "@GOLD_EXECUTABLE@" +config.ld64_executable = "@LD64_EXECUTABLE@" +config.enable_shared = @ENABLE_SHARED@ +config.enable_assertions = @ENABLE_ASSERTIONS@ +config.targets_to_build = "@TARGETS_TO_BUILD@" +config.native_target = "@LLVM_NATIVE_ARCH@" +config.llvm_bindings = "@LLVM_BINDINGS@".split(' ') +config.host_os = "@HOST_OS@" +config.host_cc = "@HOST_CC@" +config.host_cxx = "@HOST_CXX@" +# Note: ldflags can contain double-quoted paths, so must use single quotes here. +config.host_ldflags = '@HOST_LDFLAGS@' +config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@" +config.llvm_host_triple = '@LLVM_HOST_TRIPLE@' +config.host_arch = "@HOST_ARCH@" +config.mlir_hlo_src_root = "@CMAKE_SOURCE_DIR@" +config.mlir_hlo_obj_root = "@CMAKE_BINARY_DIR@" +config.mlir_runner_utils_dir = os.path.join(config.llvm_obj_root, "lib") + +# Support substitution of the tools_dir with user parameters. This is +# used when we can't determine the tool dir at configuration time. +try: + config.llvm_tools_dir = config.llvm_tools_dir % lit_config.params + config.llvm_shlib_dir = config.llvm_shlib_dir % lit_config.params +except KeyError: + e = sys.exc_info()[1] + key, = e.args + lit_config.fatal("unable to find %r parameter, use '--param=%s=VALUE'" % (key,key)) + + +import lit.llvm +lit.llvm.initialize(lit_config, config) + +# Let the main config do the real work. 
+lit_config.load_config(config, "@CMAKE_SOURCE_DIR@/tests/lit.cfg.py") diff --git a/tensorflow/compiler/mlir/hlo/tests/mhlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/hlo/tests/mhlo_infer_shape_type_methods.mlir new file mode 100644 index 00000000000..d626f520824 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/mhlo_infer_shape_type_methods.mlir @@ -0,0 +1,37 @@ +// RUN: mlir-hlo-opt --mhlo-test-infer-shaped-type-methods --allow-unregistered-dialect --split-input-file %s | FileCheck %s + +// ----- +// CHECK-LABEL: @select +// CHECK-SAME: (%[[PRED:.*]]: tensor<2x?xi1>, +func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>) + -> tensor<2xi64> { + // CHECK: %[[C2:.*]] = constant 2 : i64 + // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1> + // CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64 + // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64> + // CHECK: return %[[SHAPE]] : tensor<2xi64> + %0 = "mhlo.select"(%pred, %a, %b) + : (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32> + %1 = "mhlo_test.reify_return_type_shapes"(%0) + : (tensor<2x?xf32>) -> tensor<2xi64> + return %1 : tensor<2xi64> +} + +// ----- +// CHECK-LABEL: @compare +// CHECK-SAME: (%[[A:.*]]: tensor<2x?xf32>, +func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xi64> { + // CHECK: %[[C2:.*]] = constant 2 : i64 + // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32> + // CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64 + // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64> + // CHECK: return %[[SHAPE]] : tensor<2xi64> + %0 = "mhlo.compare"(%a, %b) { comparison_direction = "NE" } + : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1> + %1 = "mhlo_test.reify_return_type_shapes"(%0) + : (tensor<2x?xi1>) -> tensor<2xi64> + return %1 : tensor<2xi64> +} + diff --git a/tensorflow/compiler/mlir/hlo/tests/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/ops.mlir index 212e79432b1..0120a7a5652 100644 --- a/tensorflow/compiler/mlir/hlo/tests/ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/ops.mlir @@ -116,6 +116,30 @@ func @dynamic_broadcast_in_dim(%arg0: tensor, %shape: tensor<3xi64>) -> // ----- +// CHECK-LABEL: func @dynamic_broadcast_in_dim_unknown_dim +func @dynamic_broadcast_in_dim_unknown_dim(%arg0: tensor<32xf32>, %shape: tensor<3xi64>) -> tensor { + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<32xf32>, tensor<3xi64>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @dynamic_broadcast_in_dim_ok_dim +func @dynamic_broadcast_in_dim_ok_dim(%arg0: tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> { + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> + return %0 : tensor<7x8x9xf32> +} + +// ----- + +func @dynamic_broadcast_in_dim_shape_mismatch(%arg0: tensor<32xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> { + // expected-error@+1 {{size of operand dimension 0 (32) is not compatible with size of result dimension 2 (9)}} + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<32xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> + return %0 : tensor<7x8x9xf32> +} + +// ----- + func 
@broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}} %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> @@ -456,7 +480,7 @@ func @map_non_scalar_computation_operand(%arg0: tensor<4x5xf32>, %arg1: tensor<4 // expected-error@+1 {{computation arguments must be 0-rank tensor, but got: arg #1 of type 'tensor<5xf32>'}} %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor<5xf32>): - %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor + %1 = mhlo.constant dense<2.0> : tensor "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> @@ -468,7 +492,7 @@ func @map_mismatch_operand_and_computation_args(%arg0: tensor<4x5xf32>, %arg1: t // expected-error@+1 {{element type of operands and computation arguments must match, but got: 'f32' and 'i32'}} %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor + %1 = mhlo.constant dense<2.0> : tensor "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> @@ -480,7 +504,7 @@ func @map_invalid_number_of_computation_output(%arg0: tensor<4x5xf32>, %arg1: te // expected-error@+1 {{computation must return single output, but got: 0}} %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor + %1 = mhlo.constant dense<2.0> : tensor "mhlo.return"() : () -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> @@ -492,7 +516,7 @@ func @main_non_scalar_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4 // expected-error@+1 {{computation must return 0-rank tensor, but got: 'tensor<5xf32>'}} %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor<5xf32> + %1 = mhlo.constant dense<2.0> : tensor<5xf32> "mhlo.return"(%1) : (tensor<5xf32>) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> @@ -504,7 +528,7 @@ func @mismatch_computation_output_type(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5 // expected-error@+1 {{element type of result and computation output must match, but got: 'f32' and 'i32'}} %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = mhlo.constant {value = dense<2> : tensor} : tensor + %1 = mhlo.constant dense<2> : tensor "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> @@ -576,6 +600,14 @@ func @recv_non_token_second_result(%token: !mhlo.token) -> tuple // ----- +// CHECK-LABEL: func @replica_id +func @replica_id() -> tensor { + %0 = "mhlo.replica_id"() : () -> tensor + return %0 : tensor +} + +// ----- + func @rng_uniform_invalid_type(%mu: tensor>, %sigma: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> // expected-error@+1 {{but got 'tensor>'}} @@ -730,6 +762,14 @@ func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tenso // ----- +func 
@dynamic_update_slice_mismatched_start(%input: tensor<11x3x4xi32>, %update: tensor<1x3x4xi32>, %start1: tensor, %start2: tensor, %start3: tensor) -> tensor<11x3x4xi32> { + // expected-error@+1 {{start indices must have same element type (encountered mismatch: 'i32' vs 'i64')}} + %0 = "mhlo.dynamic-update-slice"(%input, %update, %start1, %start2, %start3) : (tensor<11x3x4xi32>, tensor<1x3x4xi32>, tensor, tensor, tensor) -> tensor<11x3x4xi32> + return %0 : tensor<11x3x4xi32> +} + +// ----- + // CHECK-LABEL: func @transpose func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> diff --git a/tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir b/tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir index f903dbb7080..53ee94f8d1a 100644 --- a/tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir @@ -109,7 +109,7 @@ func @batchNormInference_dynamic_shape( // CHECK-DAG: %[[C3:.*]] = constant 3 : index // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor - // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex> + // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements %[[DIM]] : tensor<1xindex> // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor) -> tensor @@ -117,7 +117,7 @@ func @batchNormInference_dynamic_shape( // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor - // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex> + // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex> // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor diff --git a/tensorflow/compiler/mlir/hlo/tools/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/tools/CMakeLists.txt new file mode 100644 index 00000000000..0f3d1c85795 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tools/CMakeLists.txt @@ -0,0 +1,16 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +add_subdirectory(mlir-hlo-opt) diff --git a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt new file mode 100644 index 00000000000..69971f4c024 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt @@ -0,0 +1,34 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +set(LIBS + ${dialect_libs} + ${conversion_libs} + MLIROptLib + + MhloRegisterDialects + AllMhloPasses + ) +add_llvm_executable(mlir-hlo-opt mlir-hlo-opt.cpp + DEPENDS + MLIRLmhloPassIncGen + MLIRMhloPassIncGen +) +llvm_update_compile_flags(mlir-hlo-opt) +target_link_libraries(mlir-hlo-opt PRIVATE ${LIBS}) + +mlir_check_all_link_libraries(mlir-hlo-opt) diff --git a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp index 70fc21d6959..d0c0e3c51e1 100644 --- a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp +++ b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp @@ -13,109 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/InitLLVM.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/ToolOutputFile.h" -#include "mlir-hlo/Dialect/mhlo/IR/register.h" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Support/FileUtilities.h" #include "mlir/Support/MlirOptMain.h" -// NOLINTNEXTLINE -static llvm::cl::opt inputFilename(llvm::cl::Positional, - llvm::cl::desc(""), - llvm::cl::init("-")); - -// NOLINTNEXTLINE -static llvm::cl::opt outputFilename( - "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), - llvm::cl::init("-")); - -// NOLINTNEXTLINE -static llvm::cl::opt splitInputFile( - "split-input-file", - llvm::cl::desc("Split the input file into pieces and process each " - "chunk independently"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt verifyDiagnostics( - "verify-diagnostics", - llvm::cl::desc("Check that emitted diagnostics match " - "expected-* lines on the corresponding line"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt verifyPasses( - "verify-each", - llvm::cl::desc("Run the verifier after each transformation pass"), - llvm::cl::init(true)); - -// NOLINTNEXTLINE -static llvm::cl::opt allowUnregisteredDialects( - "allow-unregistered-dialect", - llvm::cl::desc("Allow operation with no registered dialects"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt showDialects( - "show-dialects", llvm::cl::desc("Print the list of registered dialects"), - llvm::cl::init(false)); - int main(int argc, char **argv) { - mlir::registerAllDialects(); mlir::registerAllPasses(); - - mlir::mhlo::registerAllDialects(); mlir::mhlo::registerAllMhloPasses(); mlir::lmhlo::registerAllLmhloPasses(); - llvm::InitLLVM y(argc, argv); + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + registry.insert(); + registry.insert(); + registry.insert(); - // Register any pass manager command line options. - mlir::registerPassManagerCLOptions(); - mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run"); - - // Parse pass names in main to ensure static initialization completed. - llvm::cl::ParseCommandLineOptions(argc, argv, - "MLIR modular optimizer driver\n"); - - if (showDialects) { - mlir::MLIRContext context; - llvm::outs() << "Registered Dialects:\n"; - for (mlir::Dialect *dialect : context.getRegisteredDialects()) { - llvm::outs() << dialect->getNamespace() << "\n"; - } - return 0; - } - - // Set up the input file. - std::string errorMessage; - auto file = mlir::openInputFile(inputFilename, &errorMessage); - if (!file) { - llvm::errs() << errorMessage << "\n"; - return 1; - } - - auto output = mlir::openOutputFile(outputFilename, &errorMessage); - if (!output) { - llvm::errs() << errorMessage << "\n"; - exit(1); - } - - if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, - splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects))) { - return 1; - } - // Keep the output file if the invocation of MlirOptMain was successful. 
- output->keep(); - return 0; + return failed( + mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); } diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 555c11779f5..aee6cd5ad91 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -29,6 +29,7 @@ filegroup( "ir/tfl_ops.td", "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", ], @@ -227,6 +228,7 @@ cc_library( "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SideEffects", @@ -237,6 +239,28 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "constant_utils", + srcs = [ + "utils/constant_utils.cc", + ], + hdrs = [ + "utils/constant_utils.h", + ], + copts = ["-std=c++14"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:mangling_util", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:status", + "//tensorflow/stream_executor/lib", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "lstm_utils", srcs = [ @@ -256,6 +280,28 @@ cc_library( ], ) +cc_library( + name = "nms_utils", + srcs = [ + "utils/nms_utils.cc", + ], + hdrs = [ + "utils/nms_utils.h", + ], + copts = ["-std=c++14"], + deps = [ + ":tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", + "//tensorflow/core:framework", + "@flatbuffers", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "tftext_utils", srcs = [ @@ -347,7 +393,9 @@ cc_library( "transforms/passes.h", ], deps = [ + ":constant_utils", ":lstm_utils", + ":nms_utils", ":stateful_ops_utils", ":tensorflow_lite", ":tftext_utils", @@ -359,6 +407,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:mangling_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", "//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass", @@ -477,25 +526,13 @@ gentbl( tblgen = "//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen", td_file = "ir/tfl_ops.td", td_srcs = [ + "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", "ir/tfl_op_interfaces.td", ], ) -# Library with tensorflow Lite dialect static initialization. 
-cc_library( - name = "tensorflow_lite_dialect_registration", - srcs = [ - "ir/dialect_registration.cc", - ], - deps = [ - ":tensorflow_lite", - "@llvm-project//mlir:IR", - ], - alwayslink = 1, -) - tf_native_cc_binary( name = "converter-gen", srcs = [ @@ -602,12 +639,10 @@ cc_library( ":flatbuffer_tflite_operator_lib", ":stateful_ops_utils", ":tensorflow_lite", - ":tensorflow_lite_dialect_registration", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:protos_all_cc", @@ -646,7 +681,7 @@ cc_library( ":convert_type", ":flatbuffer_tflite_operator_lib", ":tensorflow_lite", - ":tensorflow_lite_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mangling_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/xla:statusor", @@ -714,16 +749,13 @@ cc_library( ], deps = [ ":flatbuffer_translate_lib", + ":tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirTranslateMain", "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Translation", @@ -736,7 +768,7 @@ tf_cc_binary( deps = [ ":flatbuffer_translate_registeration", # TODO(b/155809683): Link only necessary dialects. - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", ], ) @@ -788,7 +820,7 @@ tf_cc_binary( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", # TODO(b/155809683): Link only necessary dialects. - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", @@ -812,19 +844,18 @@ tf_cc_binary( deps = [ ":flatbuffer_translate_lib", ":flatbuffer_translate_registeration", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - # TODO(b/155809683): Link only necessary dialects. 
- "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Support", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + ":tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", "//tensorflow/lite:framework", "//tensorflow/lite/delegates/flex:delegate", "//tensorflow/lite/kernels:builtin_ops", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:StandardOps", ], ) @@ -844,14 +875,13 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/core:core_cpu_base", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", @@ -885,7 +915,7 @@ cc_library( "//tensorflow/stream_executor/lib", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index edead2037a3..44eba0d5e6f 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -513,7 +513,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) { continue; } if (trait.getDef().getValueAsString("trait") != - "OpTrait::TFLRuntimeOpTrait") { + "::mlir::OpTrait::TFLRuntimeOpTrait") { continue; } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 89fae87cb25..34200fb88b6 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -61,6 +61,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" @@ -133,63 +134,59 @@ static StatusOr GetTFLiteType(Type type, return Status(error::INVALID_ARGUMENT, "'isSigned' can only be set for 8-bits integer type"); } - switch (type.getKind()) { - case mlir::StandardTypes::F32: - return tflite::TensorType_FLOAT32; - case mlir::StandardTypes::F16: - return tflite::TensorType_FLOAT16; - case mlir::StandardTypes::F64: - return tflite::TensorType_FLOAT64; - case mlir::TF::TensorFlowTypes::STRING: - return tflite::TensorType_STRING; - case mlir::TF::TensorFlowTypes::QUINT8: - return tflite::TensorType_UINT8; - case mlir::StandardTypes::Complex: { - auto ftype = type.cast().getElementType(); - if (ftype && ftype.isF32()) { - return tflite::TensorType_COMPLEX64; - } - if (ftype && ftype.isF64()) { - return tflite::TensorType_COMPLEX128; - } - return Status(error::INVALID_ARGUMENT, "Unsupported type"); + + if (type.isF32()) { + return tflite::TensorType_FLOAT32; + } else if (type.isF16()) { + return tflite::TensorType_FLOAT16; + } else if (type.isF64()) { + return tflite::TensorType_FLOAT64; + } else if (type.isa()) { + return tflite::TensorType_STRING; + } else if (type.isa()) { + return tflite::TensorType_UINT8; + } else if (auto complex_type = type.dyn_cast()) { + auto ftype = complex_type.getElementType(); + if (ftype.isF32()) { + return tflite::TensorType_COMPLEX64; } - case mlir::StandardTypes::Integer: { - const auto& itype = type.cast(); - switch (itype.getWidth()) { - case 1: - return tflite::TensorType_BOOL; - case 8: - return itype.isUnsigned() ? tflite::TensorType_UINT8 - : tflite::TensorType_INT8; - case 16: - return tflite::TensorType_INT16; - case 32: - return tflite::TensorType_INT32; - case 64: - return tflite::TensorType_INT64; - } + if (ftype.isF64()) { + return tflite::TensorType_COMPLEX128; } - case mlir::quant::QuantizationTypes::UniformQuantized: { - auto qtype = type.cast(); - return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); + return Status(error::INVALID_ARGUMENT, "Unsupported type"); + } else if (auto itype = type.dyn_cast()) { + switch (itype.getWidth()) { + case 1: + return tflite::TensorType_BOOL; + case 8: + return itype.isUnsigned() ? tflite::TensorType_UINT8 + : tflite::TensorType_INT8; + case 16: + return tflite::TensorType_INT16; + case 32: + return tflite::TensorType_INT32; + case 64: + return tflite::TensorType_INT64; } - case mlir::quant::QuantizationTypes::UniformQuantizedPerAxis: { - auto qtype = type.cast(); - return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); - } - case mlir::TF::TensorFlowTypes::RESOURCE: { - // Treat tf.resource values as integer values in flatbuffer. - // TODO(b/146131919): Maybe need to have a detailed design for supporting - // other resource types beyonds hash table resources and resource - // variables. - return tflite::TensorType_INT32; - } - default: - // TFLite export fills FLOAT32 for unknown data types. Returning an error - // for now for safety and this could be revisited when required. 
- return Status(error::INVALID_ARGUMENT, "Unsupported type"); + } else if (auto q_uniform_type = + type.dyn_cast()) { + return GetTFLiteType(q_uniform_type.getStorageType(), + q_uniform_type.isSigned()); + + } else if (auto q_peraxis_type = + type.dyn_cast()) { + return GetTFLiteType(q_peraxis_type.getStorageType(), + q_peraxis_type.isSigned()); + } else if (type.isa()) { + // Treat tf.resource values as integer values in flatbuffer. + // TODO(b/146131919): Maybe need to have a detailed design for supporting + // other resource types beyonds hash table resources and resource + // variables. + return tflite::TensorType_INT32; } + // TFLite export fills FLOAT32 for unknown data types. Returning an error + // for now for safety and this could be revisited when required. + return Status(error::INVALID_ARGUMENT, "Unsupported type"); } static bool IsConst(Operation* op) { @@ -358,8 +355,13 @@ class Translator { if (emit_custom_ops) { enabled_op_types_.emplace(OpType::kCustomOp); } - tf_dialect_ = module.getContext()->getRegisteredDialect("tf"); - tfl_dialect_ = module.getContext()->getRegisteredDialect("tfl"); + tf_dialect_ = + module.getContext()->getOrLoadDialect(); + tfl_dialect_ = module.getContext() + ->getOrLoadDialect(); + // Right now the TF executor dialect is still needed to build NodeDef. + module.getContext() + ->getOrLoadDialect(); } Optional TranslateInternal(); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 3c8bf26aa14..62eaffa8ed9 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -65,6 +65,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -254,20 +255,35 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b, layer_stats, axis_stats, axis); } -StatusOr OpNameForOpCode(const tflite::OperatorCodeT opcode) { - if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) { +// Returns true if this is a basic LSTM op. +bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) { + if (const auto* op = op_union.AsLSTMOptions()) { + return op->kernel_type == tflite::LSTMKernelType_BASIC; + } else { + return false; + } +} + +// Gets the MLIR op name with the dialect name for the flatbuffer operator. 
+StatusOr GetMlirOpName(const tflite::OperatorT& op, + const tflite::OperatorCodeT& op_code) { + if (IsBasicLSTMOp(op.builtin_options)) { + return std::string("tfl.basic_lstm"); + } + + if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) { return std::string("tfl.custom"); } - if (opcode.builtin_code == tflite::BuiltinOperator_IF) { + if (op_code.builtin_code == tflite::BuiltinOperator_IF) { return std::string("tf.If"); } - if (opcode.builtin_code == tflite::BuiltinOperator_WHILE) { + if (op_code.builtin_code == tflite::BuiltinOperator_WHILE) { return std::string("tf.While"); } - const char* op_name = tflite::EnumNameBuiltinOperator(opcode.builtin_code); - std::string lowered_name = llvm::StringRef(op_name).lower(); - return llvm::Twine("tfl.", lowered_name).str(); + llvm::StringRef op_name( + tflite::EnumNameBuiltinOperator(op_code.builtin_code)); + return llvm::Twine("tfl.", op_name.lower()).str(); } // The buffers in TFLite flatbuffers have their contents stored as a vector of @@ -464,7 +480,7 @@ StatusOr BuildConstOp(const tflite::TensorT& tensor, value = mlir::DenseStringElementsAttr::get(shaped_type, refs); } else if (elem_type.isa()) { - auto dialect = elem_type.getContext()->getRegisteredDialect("tf"); + auto dialect = elem_type.getContext()->getLoadedDialect("tf"); tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer); std::string mangled = tensorflow::mangling_util::MangleTensor(repr); @@ -510,14 +526,6 @@ llvm::SmallVector ConvertSubgraphIdxsToFunctionAttrs( return {}; } -// Returns true if this is a basic LSTM op. -bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) { - if (const auto* op = op_union.AsLSTMOptions()) { - return op->kernel_type == tflite::LSTMKernelType_BASIC; - } else { - return false; - } -} // TODO(krzysd) Handle function calls StatusOr ConvertOp( @@ -525,7 +533,6 @@ StatusOr ConvertOp( const std::vector& intermediate_types, Value optional_arg_marker, const std::vector>& op_codes, - const std::vector& op_names, const std::vector& func_names, const std::vector>& tensors, Location loc, OpBuilder builder) { @@ -537,10 +544,10 @@ StatusOr ConvertOp( return emitError(loc, err.ToString()), err; } - const bool is_basic_lstm = IsBasicLSTMOp(op.builtin_options); - const tflite::OperatorCodeT op_code = *op_codes.at(op.opcode_index); - const std::string& op_name = - is_basic_lstm ? "tfl.basic_lstm" : op_names.at(op.opcode_index); + const tflite::OperatorCodeT& op_code = *op_codes.at(op.opcode_index); + + TF_ASSIGN_OR_RETURN(const std::string op_name, GetMlirOpName(op, op_code)); + OperationState op_state(loc, op_name); for (auto input_num : op.inputs) { @@ -777,7 +784,7 @@ static StatusOr PostProcessFuncOp(FuncOp func) { auto new_output_type = new_qtype.castFromExpressedType( mlir::quant::UniformQuantizedType::castToExpressedType( value.getType())); - builder.setInsertionPointAfter(cst); + builder.setInsertionPointAfter(cst.getOperation()); auto new_op = builder.create( cst.getLoc(), new_output_type, mlir::TypeAttr::get(new_output_type), cst.valueAttr()); @@ -791,8 +798,7 @@ static StatusOr PostProcessFuncOp(FuncOp func) { } // Build a FuncOp from a tflite SubGraph -// The op_names are a mapping from indexes into the TFLite operators array to -// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken +// The buffers are directly taken // from the deserialized flatbuffer as we do not have the type information to // interpret them until this point. 
The base_loc parameter is the location of // the flatbuffer as a whole (usually a file). The is_entry_point flag @@ -802,7 +808,6 @@ static StatusOr PostProcessFuncOp(FuncOp func) { StatusOr ConvertSubgraph( const tflite::SubGraphT& subgraph, llvm::StringRef name, const std::vector>& op_codes, - const std::vector& op_names, const std::vector& func_names, const std::vector>& buffers, Location base_loc, Builder builder, bool is_entry_point, @@ -1002,8 +1007,7 @@ StatusOr ConvertSubgraph( TF_ASSIGN_OR_RETURN( auto* mlir_op, ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker, - op_codes, op_names, func_names, subgraph.tensors, op_loc, - op_builder)); + op_codes, func_names, subgraph.tensors, op_loc, op_builder)); // Add the results to the value maps. There are two cases: 1. the result // tensor does not have min/max values, the original op result is used @@ -1069,6 +1073,10 @@ OwningModuleRef tflite::FlatBufferToMlir( const std::vector& ordered_input_arrays, const std::vector& ordered_output_arrays, bool experimental_prune_unreachable_nodes_unconditionally) { + context->loadDialect< + mlir::StandardOpsDialect, mlir::quant::QuantizationDialect, + mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect>(); + auto model_ptr = FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length()); if (nullptr == model_ptr) { @@ -1079,17 +1087,6 @@ OwningModuleRef tflite::FlatBufferToMlir( auto builder = Builder(context); - std::vector operator_names; - operator_names.reserve(model->operator_codes.size()); - - for (auto& opcode : model->operator_codes) { - auto operator_name_or_error = OpNameForOpCode(*opcode); - if (!operator_name_or_error.ok()) { - return emitError(base_loc, operator_name_or_error.status().ToString()), - nullptr; - } - operator_names.push_back(operator_name_or_error.ConsumeValueOrDie()); - } std::vector func_names; for (auto& subgraph : model->subgraphs) { @@ -1110,8 +1107,8 @@ OwningModuleRef tflite::FlatBufferToMlir( auto& subgraph = e.value(); std::string name = SubgraphName(e.index(), *subgraph); auto func_or_error = ConvertSubgraph( - *subgraph, name, model->operator_codes, operator_names, func_names, - model->buffers, base_loc, builder, + *subgraph, name, model->operator_codes, func_names, model->buffers, + base_loc, builder, // TODO(b/131175224,b/132239787) Support multiple entry points /*is_entry_point=*/e.index() == 0, /*use_external_constant=*/use_external_constant, ordered_input_arrays, diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index ceaa4e215cf..60fd1160be2 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -95,50 +95,44 @@ static tflite::MirrorPadMode ConvertTFL_MirrorPaddingAttrForOptionWriter( static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter( mlir::Type type, flatbuffers::FlatBufferBuilder* builder) { - switch (type.getKind()) { - case mlir::StandardTypes::F16: - return tflite::TensorType_FLOAT16; - case mlir::StandardTypes::F32: - return tflite::TensorType_FLOAT32; - case mlir::TF::TensorFlowTypes::STRING: - return tflite::TensorType_STRING; - case mlir::StandardTypes::Complex: { - auto etype = type.cast().getElementType(); - if (etype.isF32()) { - return tflite::TensorType_COMPLEX64; - } - llvm_unreachable("invalid complex Type in conversion"); + if (type.isF16()) { + return tflite::TensorType_FLOAT16; + } else if (type.isF32()) { + return tflite::TensorType_FLOAT32; 
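In the FlatBufferToMlir hunk above, the importer now calls `context->loadDialect<...>()` up front so every dialect it will create ops for is loaded before any op is built, instead of relying on global registration. A minimal sketch of that preparation step; the dialect list here is illustrative, not the importer's exact list:

```cpp
// Sketch: pre-load the dialects an importer will emit, as FlatBufferToMlir
// does above. The exact dialect list depends on what the importer builds.
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

void PrepareImportContext(mlir::MLIRContext &context) {
  // loadDialect is idempotent; loading an already-loaded dialect is harmless.
  context.loadDialect<mlir::StandardOpsDialect,
                      mlir::TFL::TensorFlowLiteDialect>();
}
```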
+ } else if (type.isa()) { + return tflite::TensorType_STRING; + } else if (auto complex_type = type.dyn_cast()) { + if (complex_type.getElementType().isF32()) { + return tflite::TensorType_COMPLEX64; } - case mlir::StandardTypes::Integer: { - const auto& itype = type.cast(); - switch (itype.getWidth()) { - case 1: - return tflite::TensorType_BOOL; - case 8: - return tflite::TensorType_INT8; - case 16: - return tflite::TensorType_INT16; - case 32: - return tflite::TensorType_INT32; - case 64: - return tflite::TensorType_INT64; - default: - llvm_unreachable("invalid integer Type in conversion"); - } + llvm_unreachable("invalid complex Type in conversion"); + } else if (auto itype = type.dyn_cast()) { + switch (itype.getWidth()) { + case 1: + return tflite::TensorType_BOOL; + case 8: + return tflite::TensorType_INT8; + case 16: + return tflite::TensorType_INT16; + case 32: + return tflite::TensorType_INT32; + case 64: + return tflite::TensorType_INT64; + default: + llvm_unreachable("invalid integer Type in conversion"); } - default: - llvm_unreachable("invalid Type in conversion"); } + llvm_unreachable("invalid Type in conversion"); } // I32Attr already returns an int as required by flatbuffer builders. static int ConvertI32AttrForOptionWriter( - llvm::APInt i, flatbuffers::FlatBufferBuilder* builder) { - return i.getSExtValue(); + int i, flatbuffers::FlatBufferBuilder* builder) { + return i; } static int ConvertPositiveI32AttrForOptionWriter( - llvm::APInt i, flatbuffers::FlatBufferBuilder* builder) { + int i, flatbuffers::FlatBufferBuilder* builder) { return ConvertI32AttrForOptionWriter(i, builder); } @@ -255,7 +249,7 @@ Status mlir::CustomOptionsToAttributes( {static_cast(custom_options.size())}, builder.getIntegerType(8)); attributes->emplace_back(builder.getNamedAttr( "custom_option", - OpaqueElementsAttr::get(builder.getContext()->getRegisteredDialect("tfl"), + OpaqueElementsAttr::get(builder.getContext()->getLoadedDialect("tfl"), type, content))); return Status::OK(); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index 5b95b30a96c..94f7e2261f7 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -17,6 +17,7 @@ limitations under the License. #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -33,6 +34,8 @@ limitations under the License. 
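In the flatbuffer_translate.cc hunk that continues below, MLIRToFlatBufferTranslate gains a dialect-registration callback, so the translation declares the dialects it needs rather than depending on global registration. A hedged sketch of that registration shape, assuming the three-argument `TranslateFromMLIRRegistration` overload used below; the translation name and function here are placeholders:

```cpp
// Sketch: a translate registration that declares its own dialects, mirroring
// the MLIRToFlatBufferTranslate change below. "MyTranslateFunction" and
// "my-translation" are placeholders, not part of the patch.
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Translation.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

static mlir::LogicalResult MyTranslateFunction(mlir::ModuleOp module,
                                               llvm::raw_ostream &output) {
  // A real translation serializes `module` here; this stub just succeeds.
  return mlir::success();
}

static mlir::TranslateFromMLIRRegistration registration(
    "my-translation", MyTranslateFunction,
    [](mlir::DialectRegistry &registry) {
      // Only the dialects this translation actually consumes are registered.
      registry.insert<mlir::StandardOpsDialect,
                      mlir::TFL::TensorFlowLiteDialect>();
    });
```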
#include "mlir/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" using llvm::cl::opt; @@ -175,5 +178,11 @@ static TranslateToMLIRRegistration FlatBufferFileToMlirTransReg( }); static TranslateFromMLIRRegistration MLIRToFlatBufferTranslate( - "mlir-to-tflite-flatbuffer", MlirToFlatBufferFileTranslateFunction); + "mlir-to-tflite-flatbuffer", MlirToFlatBufferFileTranslateFunction, + [](DialectRegistry& registry) { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + }); } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index ae1e3ebe5e6..2894af9b97e 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project @@ -40,10 +41,10 @@ limitations under the License. #include "mlir/Transforms/FoldUtils.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { -#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc" namespace TFL { // Returns true when the given operand arguments have the same shape or @@ -253,9 +254,8 @@ struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface { } }; -struct TensorFlowLiteOpFolderDialectInterface - : public OpFolderDialectInterface { - using OpFolderDialectInterface::OpFolderDialectInterface; +struct TensorFlowLiteDialectFoldInterface : public DialectFoldInterface { + using DialectFoldInterface::DialectFoldInterface; // Registered hook to check if the given region, which is attached to an // operation that is *not* isolated from above (i.e. no internal regions @@ -269,13 +269,13 @@ struct TensorFlowLiteOpFolderDialectInterface }; TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context) - : Dialect(/*name=*/"tfl", context) { + : Dialect(/*name=*/"tfl", context, TypeID::get()) { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc" >(); addInterfaces(); + TensorFlowLiteDialectFoldInterface>(); } //===----------------------------------------------------------------------===// @@ -569,7 +569,7 @@ namespace { int64_t GetConcatenationOpAxis(ConcatenationOp op) { auto output_type = op.output().getType().cast(); - int64_t axis = op.axis().getSExtValue(); + int32_t axis = op.axis(); if (axis < 0) axis += output_type.getRank(); return axis; } @@ -1027,10 +1027,13 @@ static LogicalResult Verify(PackOp op) { // Check axis bounds. 
if (input_type.hasRank()) { - int64_t axis_value = op.axis().getSExtValue(); - if (abs(axis_value) > input_type.getRank()) - return op.emitOpError("op attribute 'axis' is out of bounds, got ") - << axis_value; + int32_t axis_value = op.axis(); + if (axis_value < 0) axis_value += input_type.getRank() + 1; + if (axis_value < 0 || axis_value >= input_type.getRank() + 1) + return op.emitOpError() + << "op attribute 'axis' should be in range [-rank - 1, rank + 1), " + << "got rank = " << input_type.getRank() + << ", and axis = " << op.axis(); } // Make sure all inputs have the same shape and element type. @@ -1443,12 +1446,59 @@ void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // TODO(b/133486129): Implement shape inference for unpack -static LogicalResult Verify(UnpackOp op) { - // TODO(antiagainst): Implement other checks as in - // tensorflow/lite/kernels/unpack.cc +LogicalResult UnpackOp::inferReturnTypes( + MLIRContext *context, Optional loc, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + UnpackOpAdaptor op(operands, attributes); + // TODO(jpienaar): Refactor verify + if (failed(op.verify(loc.hasValue() ? *loc : UnknownLoc::get(context)))) + return failure(); - if (op.getOperation()->getNumResults() != op.num()) - return op.emitOpError("output count should match 'num' attribute"); + if (operands.size() != 1) { + return emitOptionalError(loc, "input count should be equal to 1"); + } + + const int64_t num_value = op.num().getInt(); + auto input_type = operands[0].getType().dyn_cast(); + if (!input_type || !input_type.hasRank()) { + // If input is unranked, then so is output. + inferredReturnTypes.assign( + num_value, UnrankedTensorType::get(input_type.getElementType())); + return success(); + } + + if (input_type.hasStaticShape() && input_type.getNumElements() <= 0) { + return emitOptionalError( + loc, "number of elements in input shoule be larger than 0"); + } + + const int64_t rank = input_type.getRank(); + if (rank <= 0) { + return emitOptionalError(loc, "input should be of rank larger than 0"); + } + + int64_t axis_value = op.axis().getInt(); + if (axis_value < 0) { + axis_value += rank; + } + if (axis_value < 0 || axis_value >= rank) { + return emitOptionalError( + loc, "attribute 'axis' should be in range [-rank, rank), got axis = ", + op.axis().getInt(), ", and rank = ", rank); + } + + if (!ShapedType::isDynamic(input_type.getDimSize(axis_value)) && + input_type.getDimSize(axis_value) != num_value) { + return emitOptionalError(loc, "output count should match 'num' attribute"); + } + + auto output_shape = llvm::to_vector<4>(input_type.getShape()); + output_shape.erase(output_shape.begin() + axis_value); + + auto output_type = + RankedTensorType::get(output_shape, input_type.getElementType()); + inferredReturnTypes.assign(num_value, output_type); return success(); } @@ -1495,7 +1545,7 @@ static LogicalResult VerifySplitOpOutputTypes( } static LogicalResult Verify(SplitOp op) { - int64_t num_splits = op.num_splits().getSExtValue(); + int64_t num_splits = op.num_splits(); if (op.getNumResults() != num_splits) return op.emitOpError("output count should match 'num_splits' attribute"); @@ -1531,7 +1581,7 @@ static LogicalResult Verify(SplitOp op) { } static LogicalResult Verify(SplitVOp op) { - int64_t num_splits = op.num_splits().getSExtValue(); + int64_t num_splits = op.num_splits(); if (op.getNumResults() != num_splits) return op.emitOpError("output count should match 'num_splits' 
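The PackOp verifier and the new UnpackOp::inferReturnTypes above both normalize a possibly negative `axis` before range-checking it (pack allows [-rank - 1, rank + 1) because it inserts a dimension; unpack allows [-rank, rank) and then drops that dimension to form each result type). The arithmetic for the unpack case, isolated into a standalone sketch with plain C++ containers instead of MLIR types:

```cpp
// Standalone sketch of the axis handling in the UnpackOp hunk above:
// normalize a negative axis, bounds-check it, and compute the shape of each
// unpacked result by erasing that dimension.
#include <cstdint>
#include <optional>
#include <vector>

std::optional<std::vector<int64_t>> UnpackResultShape(
    std::vector<int64_t> input_shape, int64_t axis) {
  const int64_t rank = static_cast<int64_t>(input_shape.size());
  if (rank <= 0) return std::nullopt;  // input must have rank larger than 0
  if (axis < 0) axis += rank;          // e.g. axis = -1, rank = 2 -> axis = 1
  if (axis < 0 || axis >= rank) return std::nullopt;  // must be in [-rank, rank)
  // Each of the `num` results has the input shape with `axis` removed.
  input_shape.erase(input_shape.begin() + axis);
  return input_shape;
}

// Example: UnpackResultShape({2, 3}, -1) yields {2}, so each of the three
// results of "tfl.unpack"(tensor<2x3xi32>) {axis = -1, num = 3} is
// tensor<2xi32>, matching the ops.mlir tests later in this patch.
```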
attribute"); @@ -2327,8 +2377,16 @@ LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef ops) { //===----------------------------------------------------------------------===// #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc" + +} // namespace TFL +} // namespace mlir + #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc" + +namespace mlir { +namespace TFL { + #include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc" Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder, diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index caed0bb3ad9..589f18d789d 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -26,14 +26,15 @@ limitations under the License. #include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.h.inc" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/lite/schema/schema_generated.h" namespace mlir { -#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.h.inc" namespace TFL { class TensorFlowLiteDialect : public Dialect { @@ -49,10 +50,11 @@ class TensorFlowLiteDialect : public Dialect { }; #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc" -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" } // end namespace TFL } // end namespace mlir +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" + #endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_ diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 6dc9fda656f..1b91c0dbe61 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -19,6 +19,7 @@ limitations under the License. #define TFL_OPS include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td" @@ -38,7 +39,7 @@ def TFL_Dialect : Dialect { represented using zero-dimensional tensors); }]; - let cppNamespace = "TFL"; + let cppNamespace = "::mlir::TFL"; } //===----------------------------------------------------------------------===// @@ -107,7 +108,11 @@ def OpaqueBytesAttr : ElementsAttrBase< ".getElementType().isInteger(8)">, ]>, "opaque bytes attribute" - >; + > { + let storageType = [{ OpaqueElementsAttr }]; + let returnType = [{ OpaqueElementsAttr }]; + let convertFromStorage = "$_self"; +} //===----------------------------------------------------------------------===// // Derived shape attribute class. 
@@ -2442,8 +2447,7 @@ def TFL_ReluOp: TFL_Op<"relu", [ PredOpTrait<"x and y must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, - SameOperandsAndResultShape, - SameOperandsAndResultsScale]> { + SameOperandsAndResultShape]> { let summary = "Relu operator"; let description = [{ @@ -2471,8 +2475,7 @@ def TFL_Relu6Op: TFL_Op<"relu6", [ PredOpTrait<"x and y must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, - SameOperandsAndResultShape, - SameOperandsAndResultsScale]> { + SameOperandsAndResultShape]> { let summary = "Relu6 operator"; let description = [{ @@ -2500,8 +2503,7 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [ PredOpTrait<"x and y must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, - SameOperandsAndResultShape, - SameOperandsAndResultsScale]> { + SameOperandsAndResultShape]> { let summary = "Relu1 operator"; let description = [{ @@ -3024,7 +3026,8 @@ def TFL_TransposeOp : TFL_Op<"transpose", [ def TFL_UnpackOp : TFL_Op<"unpack", [ NoSideEffect, SameOperandsAndResultElementType, - SameOperandsAndResultsScale]> { + SameOperandsAndResultsScale, + DeclareOpInterfaceMethods]> { let summary = "Unpacks a tensor along a dimension into multiple tensors"; let description = [{ @@ -3047,7 +3050,7 @@ def TFL_UnpackOp : TFL_Op<"unpack", [ let arguments = (ins TFL_TensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$input, - I32Attr:$num, + Confined:$num, I32Attr:$axis ); @@ -3055,8 +3058,6 @@ def TFL_UnpackOp : TFL_Op<"unpack", [ TFL_VariadicTensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$outputs ); - let verifier = [{ return Verify(*this); }]; - let hasOptions = 1; } diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc index 0d42fbb9646..35a58a01a29 100644 --- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc +++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc @@ -30,12 +30,16 @@ limitations under the License. #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Parser.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/delegates/flex/delegate.h" @@ -98,6 +102,10 @@ int main(int argc, char** argv) { // Load the MLIR module. 
mlir::MLIRContext context; + context.getDialectRegistry() + .insert(); + llvm::SourceMgr source_mgr; source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context)); diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 6299a70b1df..7e7d4678a87 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -62,6 +62,10 @@ class ImportQuantStatsPass void runOnFunction() override; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + // Parses the serialized quant stats protobuf and initialize the internal // data structure. This method must be called after the pass is created. bool ParseQuantStats(const std::string &stats_str); diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 31c0e4cb8a9..38c7ad86e05 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -28,6 +28,7 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/lite:common", "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/tensorflow:error_util", @@ -74,6 +75,6 @@ tf_cc_binary( "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", ], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index a2e3c065113..238710bcf13 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -25,6 +25,7 @@ limitations under the License. 
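ImportQuantStatsPass above now declares the dialects it may create ops for via `getDependentDialects`, the hook passes use to have those dialects loaded before they run once global registration is gone. A hedged sketch of the same hook on a minimal pass; the pass class here is a placeholder, not part of the patch:

```cpp
// Sketch: a pass declaring its dependent dialects, as ImportQuantStatsPass
// does above. "ExamplePass" is a placeholder.
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"

namespace {
struct ExamplePass
    : public mlir::PassWrapper<ExamplePass, mlir::FunctionPass> {
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    // Any dialect whose ops this pass may build must be listed here so the
    // context loads it before the pass runs.
    registry.insert<mlir::quant::QuantizationDialect>();
  }
  void runOnFunction() override {}
};
}  // namespace
```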
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" @@ -52,6 +53,7 @@ TfLiteStatus QuantizeModel( } MLIRContext context; + context.getDialectRegistry().insert(); StatusScopedDiagnosticHandler statusHandler(&context, /*propagate=*/true); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 9e0ad990657..16b51496b5f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -99,12 +99,14 @@ class QuantizationDriver { public: explicit QuantizationDriver(FuncOp fn, bool is_signed, bool disable_per_channel, - OpQuantSpecGetter op_quant_spec_getter) + OpQuantSpecGetter op_quant_spec_getter, + bool enforce_fixed_output_range) : fn_(fn), builder_(fn.getBody()), is_signed_(is_signed), disable_per_channel_(disable_per_channel), - op_quant_spec_getter_(op_quant_spec_getter) {} + op_quant_spec_getter_(op_quant_spec_getter), + enforce_fixed_output_range_(enforce_fixed_output_range) {} // The entry point of the quantization parameters propagation. void Run(); @@ -354,6 +356,8 @@ class QuantizationDriver { llvm::SmallVector args_; OpQuantSpecGetter op_quant_spec_getter_; + + bool enforce_fixed_output_range_; }; } // namespace @@ -794,7 +798,8 @@ bool QuantizationDriver::PropagateParams() { } // TODO(fengliuai): make the bit width configurable. - if (auto restricted = llvm::dyn_cast(op)) { + auto restricted = llvm::dyn_cast(op); + if (restricted && enforce_fixed_output_range_) { // TODO(fengliuai): different result can have different fixed range. 
auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8); for (auto i = 0; i < op->getNumResults(); ++i) { @@ -864,10 +869,12 @@ void QuantizationDriver::Run() { } } -void ApplyQuantizationParamsPropagation( - mlir::FuncOp func, bool is_signed, bool disable_per_channel, - OpQuantSpecGetter op_quant_spec_getter) { - QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter) +void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed, + bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, + bool post_training_quantization) { + QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter, + post_training_quantization) .Run(); } diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 07e5ba4e879..eb9843f6e4a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -106,9 +106,9 @@ struct ConvertStatsToQDQs : public OpRewritePattern { mins.push_back(FloatAttr::getValueAsDouble(*it++)); maxs.push_back(FloatAttr::getValueAsDouble(*it)); } - quant_type = quant::fakeQuantAttrsToType( - op.getLoc(), num_bits, op.axis()->getSExtValue(), mins, maxs, - narrow_range, expressed, is_signed); + quant_type = + quant::fakeQuantAttrsToType(op.getLoc(), num_bits, *op.axis(), mins, + maxs, narrow_range, expressed, is_signed); } else if (auto stats = op.layerStats().dyn_cast()) { double rmin = FloatAttr::getValueAsDouble(stats.getValue({0})); double rmax = FloatAttr::getValueAsDouble(stats.getValue({1})); @@ -119,7 +119,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern { return failure(); } - rewriter.setInsertionPointAfter(op); + rewriter.setInsertionPointAfter(op.getOperation()); Type result_type = quant_type.castFromExpressedType(op.getType()); auto q = rewriter.create(op.getLoc(), result_type, op.arg()); auto dq = rewriter.create(op.getLoc(), op.getType(), q); @@ -490,9 +490,13 @@ quant::QuantizedType GetUniformQuantizedTypeForBias( // and the propagation results are materialized by inserting pairs of quantize // and dequantize ops to this function. Set `disable_per_channel` to true to not // use per channel quantization even the op supports it. +// Setting `enforce_fixed_output_range` to true, to infer quantization +// parameters from the fixed output range ops. This is only used for +// post-training quantization. void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed, bool disable_per_channel, - OpQuantSpecGetter op_quant_spec_getter); + OpQuantSpecGetter op_quant_spec_getter, + bool enforce_fixed_output_range); // The function might contain more stats ops than required, and it will // introduce requantize if the calibration stats have conflicts. This method diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc index 0826b3265f6..b043834188c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc @@ -106,9 +106,8 @@ struct InsertQuantOpsAfterTFFakeQuantOp } // Use the min/max from the operands and the num_bits and narrow_range // attribute to create the quantization parameter for the new quantize op. 
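The new `enforce_fixed_output_range` flag documented above gates whether propagation overwrites a result's quantization parameters with the op's fixed output range, which matters for post-training quantization of ops whose output range is known a priori (sigmoid-like ops, for example). For reference, a generic affine-quantization sketch of how a fixed real range maps to a scale and zero point; this is the textbook formula, not necessarily the exact parameters TFLite pins for any particular op:

```cpp
// Generic affine quantization of a known, fixed real range [rmin, rmax] onto
// a signed 8-bit grid. Illustrative only; kernels may pin slightly different
// parameters (e.g. power-of-two scales) for specific ops.
#include <cmath>
#include <cstdint>
#include <cstdio>

struct QuantParams {
  double scale;
  int32_t zero_point;
};

QuantParams FixedRangeParams(double rmin, double rmax) {
  constexpr int32_t qmin = -128, qmax = 127;  // int8 storage
  QuantParams p;
  p.scale = (rmax - rmin) / (qmax - qmin);
  p.zero_point = static_cast<int32_t>(std::lround(qmin - rmin / p.scale));
  return p;
}

int main() {
  // A [0, 1] output range gives scale ~= 1/255 and zero point -128.
  QuantParams p = FixedRangeParams(0.0, 1.0);
  std::printf("scale=%f zero_point=%d\n", p.scale, p.zero_point);
  return 0;
}
```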
- rewriter.setInsertionPointAfter(tf_op); - IntegerAttr num_bits = - rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue()); + rewriter.setInsertionPointAfter(tf_op.getOperation()); + IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits()); BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range()); Type res_type = tf_op.getType(); TypeAttr qtype = quant::GetQuantizedTypeAttr( diff --git a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc index 208fb4c8a56..fc56ad05535 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc @@ -55,7 +55,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { for (const auto t : op.getTraits()) { if (auto opTrait = llvm::dyn_cast(&t)) { auto trait = opTrait->getTrait(); - if (!trait.consume_front("OpTrait::quant::")) continue; + if (!trait.consume_front("::mlir::OpTrait::quant::")) continue; OUT(2) << "if (auto tfl = llvm::dyn_cast<" << op.getQualCppClassName() << ">(op)) {\n"; @@ -65,7 +65,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { OUT(4) << "for (int i = 0, e = op->getNumResults(); i != e; ++i)\n"; OUT(6) << "spec->restricted_output_params[std::make_pair(" << matches[1] << ", " << matches[2] - << ")].push_back(tfl.OpTrait::quant::" << trait << "<" + << ")].push_back(tfl.::mlir::OpTrait::quant::" << trait << "<" << op.getQualCppClassName() << ">::GetResultQuantizedType(i));\n"; matches.clear(); diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt new file mode 100644 index 00000000000..5f498a404a9 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt @@ -0,0 +1,232 @@ +# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,8,8,2 -tf-input-data-types=DT_FLOAT -tf-output-arrays=output_0 -print-function-result-mapping %s -o - 2>&1 | FileCheck %s + +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 8 + } + dim { + size: 8 + } + dim { + size: 2 + } + } + } + } +} +node { + name: "conv_net_2d/conv_2d_0/w" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 2 + } + dim { + size: 2 + } + } + tensor_content: ";;\177<5\241i\275\312f\211>#\346j>\033W\325\275\253>\210=Vr\r\276\304\222\313\276\374\346\214>\016e\211>)\253\000>\3241\337\275\235g-\276*(\216\276\326#\367\274\023\213\300\276\227\031\206>PUF=\253\330\263<\337IL\276\334\320\215>\377\306v\276\372C\302\273baM>H\314\270<2\221\352=J\026{\276\221\243\245\276?\314\240=UW2\2755\207\253\274\256\207\333\273\335\372\227>\246\232;\276%\r\374" + } + } + } +} +node { + name: "conv_net_2d/conv_2d_0/w/read" + op: "Identity" + input: "conv_net_2d/conv_2d_0/w" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@conv_net_2d/conv_2d_0/w" + } + } + } +} +node { + name: "conv_net_2d_1/conv_2d_0/convolution" + op: "Conv2D" + input: "input" + input: "conv_net_2d/conv_2d_0/w/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: 
"data_format" + value { + s: "NCHW" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "explicit_paddings" + value { + list { + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "conv_net_2d/conv_2d_0/b" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\315\314\314=\315\314\314=" + } + } + } +} +node { + name: "conv_net_2d/conv_2d_0/b/read" + op: "Identity" + input: "conv_net_2d/conv_2d_0/b" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@conv_net_2d/conv_2d_0/b" + } + } + } +} +node { + name: "conv_net_2d_1/conv_2d_0/BiasAdd" + op: "BiasAdd" + input: "conv_net_2d_1/conv_2d_0/convolution" + input: "conv_net_2d/conv_2d_0/b/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } +} +node { + name: "conv_net_2d_1/Relu" + op: "Relu" + input: "conv_net_2d_1/conv_2d_0/BiasAdd" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "output_0" + op: "Identity" + input: "conv_net_2d_1/Relu" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +library { +} + +# CHECK: 'main' inputs: +# CHECK-NEXT: name: 'input' +# CHECK-NEXT: 'main' outputs: +# CHECK-NEXT: name: 'output_0' diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt index f482e3db6b9..a7f6040f211 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=4:4 -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s +# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=: -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s node { name: "tf.Less" op: "Less" diff --git a/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir index f6f32e7a069..138614d81e6 100644 --- a/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir +++ b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir @@ -3435,4 +3435,19 @@ func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_ } // CHECK: func @ngrams_ragged_rank_2(%arg0: tensor {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor {tf._user_specified_name = "args_1"}) -> (tensor, tensor<3xi64>, tensor) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape, #tf.shape<3>, #tf.shape], tf.signature.is_stateful} { // CHECK: %0:3 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "tftext:Ngrams", custom_option = opaque<"tfl", "0x776964746800737472696E675F736570617261746F720000006178697300726564756374696F6E5F74797065000B535452494E475F4A4F494E0004221E373E040104FF152C0204141404082401"> : tensor<77xi8>} : (tensor, tensor<3xi64>, tensor) -> (tensor, 
tensor<3xi64>, tensor) -// CHECK: return %0#0, %0#1, %0#2 : tensor, tensor<3xi64>, tensor \ No newline at end of file +// CHECK: return %0#0, %0#1, %0#2 : tensor, tensor<3xi64>, tensor + + +func @sgnn_projection(%arg0: tensor {tf._user_specified_name = "values"}, %arg1: tensor {tf._user_specified_name = "row_splits"}) -> tensor attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<[[1902835825], [-1475704015], [473120514], [1254202069], [1558833093], [1756181982], [1906603252], [-1034142694], [542842690], [535515822]]> : tensor<10x1xi64>} : () -> tensor<10x1xi64> + %1 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 2147483647 : i64} : (tensor) -> tensor + %2 = "tf.Sgnn"(%1, %0) {device = ""} : (tensor, tensor<10x1xi64>) -> tensor<10x?xf64> + %3 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> + %4 = "tf.Reshape"(%2, %3) : (tensor<10x?xf64>, tensor<1xi64>) -> tensor + return %4 : tensor +} + + +// CHECK: func @sgnn_projection(%arg0: tensor {tf._user_specified_name = "values"}, %arg1: tensor {tf._user_specified_name = "row_splits"}) -> tensor attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape, #tf.shape], tf.signature.is_stateful} { +// CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "tftext:custom:SgnnProjection", custom_option = opaque<"tfl", "0x686173685F736565640000000A00000071F86A71318B0AA8023F331CD59AC14AC5E7E95CDE35AD68F474A4711A3C5CC2421F5B20AE52EB1F6275636B6574730002094200030000000100000002000000FFFFFF7F44000000062E0A2601"> : tensor<93xi8>} : (tensor, tensor) -> tensor +// CHECK: return %0 : tensor diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir index 90266b4e78e..3c390df74b4 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir @@ -1,12 +1,11 @@ -// RUN: tf-opt %s -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s +// RUN: tf-opt %s -tfl-prepare-tf -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3x3xbf16> { %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16> return %0: tensor<3x3xbf16> // CHECK-LABEL: broadcast_to_bf16 -// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor -// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi64>, tensor) -> tensor<3x3xbf16> -// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16> +// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xbf16> +// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[CST]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16> // CHECK: return [[MUL]] : tensor<3x3xbf16> } diff --git 
a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 7cb9c4dd22c..3a2a0a8b9d2 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s +// RUN: tf-opt %s -tfl-legalize-tf --cse | FileCheck %s func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> @@ -196,7 +196,6 @@ func @shape(%arg0: tensor) -> tensor<2xi32> { // CHECK-LABEL: shape // CHECK: "tfl.shape"(%arg0) : (tensor) -> tensor<2xi32> -// CHECK: %1 = "tfl.shape"(%arg0) : (tensor) -> tensor<2xi32> } func @fill(%arg0: tensor<3xi32>, %arg1: tensor) -> tensor { @@ -719,9 +718,8 @@ func @matrix_diag_v2_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32> // CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32> -// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32> // CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32> -// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> +// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> // CHECK: return [[VAL_4]] : tensor<8x16x16xf32> } @@ -753,9 +751,8 @@ func @matrix_diag_v3_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32> // CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32> -// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32> // CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32> -// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> +// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> // CHECK: return [[VAL_4]] : tensor<8x16x16xf32> } @@ -1029,14 +1026,48 @@ func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor, tensor<2xi32>, tensor) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>) } -func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { +func @matmul(%arg0: tensor<40x37xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> { + %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = false} : +(tensor<40x37xf32>, tensor<37x40xf32>) -> tensor<40x40xf32> + return %0 : tensor<40x40xf32> +// CHECK-LABEL: matmul +// CHECK: %[[CST:.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: %[[ARG:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32> +// CHECK: %[[CST_0:.*]] = constant unit +// CHECK: "tfl.fully_connected"(%arg0, %[[ARG]], %[[CST_0]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = 
"DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> +} + +func @matmul_transposed_a(%arg0: tensor<37x40xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> { + %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = false} : +(tensor<37x40xf32>, tensor<37x40xf32>) -> tensor<40x40xf32> + return %0 : tensor<40x40xf32> +// CHECK-LABEL: matmul_transposed_a +// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32> +// CHECK: %[[ARG_1:.*]] = "tfl.transpose"(%arg1, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32> +// CHECK: %[[CST_2:.*]] = constant unit +// CHECK: "tfl.fully_connected"(%[[ARG_0]], %[[ARG_1]], %[[CST_2]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> +} + +func @matmul_transposed_b(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = true} : (tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> return %0 : tensor<40x40xf32> -// CHECK-LABEL: matmul_transposed +// CHECK-LABEL: matmul_transposed_b // CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> } +func @matmul_transposed_ab(%arg0: tensor<37x40xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { + %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = true} : +(tensor<37x40xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> + return %0 : tensor<40x40xf32> +// CHECK-LABEL: matmul_transposed_ab +// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32> +// CHECK: %[[CST_1:.*]] = constant unit +// CHECK: "tfl.fully_connected"(%[[ARG_0]], %arg1, %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> +} + func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor) -> tensor<2x3xi32> @@ -1324,10 +1355,7 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, % // CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> // CHECK: %[[CST_0:.*]] = constant unit // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> - // CHECK: %[[CST_1:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32> - // CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> - // CHECK: %[[CST_2:.*]] = constant unit - // CHECK: 
%[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2, %[[CST_2]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> + // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> // CHECK: %[[RESULT:.*]] = tfl.add %[[ARG1]], %[[ARG3]] {fused_activation_function = "NONE"} : tensor<15x28x28x1xf32> // CHECK: return %[[RESULT]] : tensor<15x28x28x1xf32> } @@ -1482,28 +1510,6 @@ func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) { // CHECK: return [[VAL_4]] : tensor<28x1x28xf32> // CHECK: } -func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { - %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> - return %0: tensor<3x3xf32> - -// CHECK-LABEL: broadcast_to_f32 -// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor -// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor) -> tensor<3x3xf32> -// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> -// CHECK: return [[MUL]] : tensor<3x3xf32> -} - -func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> { - %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> - return %0: tensor<3x3xi32> - -// CHECK-LABEL: broadcast_to_i32 -// CHECK: [[CST:%.*]] = constant dense<1> : tensor -// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor) -> tensor<3x3xi32> -// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> -// CHECK: return [[MUL]] : tensor<3x3xi32> -} - func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> { %0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} : (tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32> @@ -1555,3 +1561,27 @@ func @add_with_int32_5d_inputs(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x // CHECK-LABEL: add_with_int32_5d_inputs // CHECK: "tf.Add"(%arg0, %arg1) } + +func @tranpose_int32_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { + %cst = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } : () -> tensor<2xi32> + %0 = "tf.Transpose"(%arg0, %cst): (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> + // CHECK-LABEL: tranpose_int32_perm + // CHECK: "tfl.transpose" +} + +func @tranpose_int64_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { + %cst = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> + %0 = "tf.Transpose"(%arg0, %cst): (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> + // CHECK-LABEL: tranpose_int64_perm + // CHECK: "tfl.transpose" +} + +func @tranpose_arg(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<3x2xf32> { + %0 = "tf.Transpose"(%arg0, %arg1): (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> + // CHECK-LABEL: tranpose_arg + // CHECK: "tfl.transpose" +} + diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir 
b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir deleted file mode 100644 index 7e9f66baa90..00000000000 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - 2>&1 | FileCheck %s - -func @main(tensor<3x2xi32>) -> tensor<3x2xi32> { -^bb0(%arg0: tensor<3x2xi32>): - // CHECK: error: 'unknown_op' op dialect is not registered - %0 = "unknown_op"(%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> - return %0 : tensor<3x2xi32> -} diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 7ef6997f938..b62f5655183 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -1139,9 +1139,15 @@ func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x // ----- -func @packNegInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> { +func @packNegInputAxis2(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x2x4xi32> { // CHECK: "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} - %0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32> + %0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x2x4xi32> + return %0 : tensor<1x2x4xi32> +} + +func @packNegInputAxis3(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> { + // CHECK: "tfl.pack"(%arg0, %arg1) {axis = -3 : i32, values_count = 2 : i32} + %0 = "tfl.pack"(%arg0, %arg1) {axis = -3 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32> return %0 : tensor<2x1x4xi32> } @@ -1172,7 +1178,7 @@ func @pack(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { // ----- func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { - // expected-error @+1 {{op attribute 'axis' is out of bounds, got 3}} + // expected-error @+1 {{op attribute 'axis' should be in range [-rank - 1, rank + 1), got rank = 1, and axis = 3}} %0 = "tfl.pack"(%arg0, %arg1) {axis = 3 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -1183,7 +1189,22 @@ func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { // CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) return %0#0 : tensor<2xi32> +} +// ----- + +func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { + // CHECK: "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32} + %0:3 = "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) + return %0#0 : tensor<2xi32> +} + +// ----- + +func @unpack(%arg0: tensor<2x3xi32>) -> tensor<3xi32> { + // CHECK: "tfl.unpack"(%arg0) {axis = -2 : i32, num = 2 : i32} + %0:2 = "tfl.unpack"(%arg0) {axis = -2 : i32, num = 2 : i32} : (tensor<2x3xi32>) -> (tensor<3xi32>, tensor<3xi32>) + return %0#0 : tensor<3xi32> } // ----- @@ -1204,6 +1225,45 @@ func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { // ----- +func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { + // expected-error @+1 {{attribute 'axis' should be in range [-rank, rank), got axis = 2, and rank = 2}} + %0:3 = "tfl.unpack"(%arg0) {axis = 2 : i32, num = 3 : 
i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) + return %0#0 : tensor<2xi32> +} + +// ----- + +func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { + // expected-error @+1 {{attribute 'axis' should be in range [-rank, rank), got axis = -3, and rank = 2}} + %0:3 = "tfl.unpack"(%arg0) {axis = -3 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) + return %0#0 : tensor<2xi32> +} + +// ----- + +func @unpack(%arg0: tensor) -> tensor<2xi32> { + // expected-error @+1 {{input should be of rank larger than 0}} + %0:3 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 3 : i32} : (tensor) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) + return %0#0 : tensor<2xi32> +} + +// ----- + +func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { + // expected-error @+1 {{op inferred type incompatible with return type of operation}} + %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2x1xi32>, tensor<2xi32>) + return %0#0 : tensor<2xi32> +} + +// ----- + +func @unpack(%arg0: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) { + %0:2 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 2 : i32} : (tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) + return %0#0, %0#1 : tensor<*xi32>, tensor<*xi32> +} + +// ----- + // CHECK-LABEL: testMean func @testMean(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> { // CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = false} @@ -1640,6 +1700,15 @@ func @testRelu6WithQuantizedTypes(%arg0 : tensor<10x!quant.uniform> // ----- +func @testReluWithDifferentScales(%arg0 : tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> { + %0 = "tfl.relu"(%arg0) : (tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> + %1 = "tfl.relu_n1_to_1"(%0) : (tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> + %2 = "tfl.relu6"(%1) : (tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> + return %2 : tensor<10x!quant.uniform> +} + +// ----- + func @testEmbeddingLookup(%arg0 : tensor, %arg1 : tensor) -> tensor { %0 = "tfl.embedding_lookup"(%arg0, %arg1) : (tensor,tensor) -> tensor return %0 : tensor diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index b8be96a9159..8d64bc6ed0a 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -50,6 +50,96 @@ func @fuseSubIntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) } +// CHECK-LABEL: fuseAddIntoTransposeConv +func @fuseAddIntoTransposeConv(%arg0: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { + %cst = constant dense<1.5> : tensor<32xf32> + %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %cst_1 = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + %cst_2 = constant dense<1.0> : tensor<32x4x4x128xf32> + %cst_3 = constant dense<[1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]> : tensor<32xf32> + %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> 
tensor<1x64x84x32xf32> + return %1 : tensor<1x64x84x32xf32> + + // CHECK: %[[SHAPE:.*]] = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + // CHECK: %[[WEIGHTS:.*]] = constant dense<1.000000e+00> : tensor<32x4x4x128xf32> + // CHECK: %[[BIAS:.*]] = constant dense<[2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00, 2.500000e+00, 3.500000e+00]> : tensor<32xf32> + // CHECK: %[[RESULT:.*]] = "tfl.transpose_conv"(%[[SHAPE]], %[[WEIGHTS]], %arg0, %[[BIAS]]) + // CHECK: return %[[RESULT]] +} + +// CHECK-LABEL: fuseSubIntoTransposeConv +func @fuseSubIntoTransposeConv(%arg0: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { + %cst = constant dense<1.5> : tensor<32xf32> + %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %cst_1 = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + %cst_2 = constant dense<1.0> : tensor<32x4x4x128xf32> + %cst_3 = constant dense<[1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]> : tensor<32xf32> + %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + %1 = "tfl.sub"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + return %1 : tensor<1x64x84x32xf32> + + // CHECK: %[[SHAPE:.*]] = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + // CHECK: %[[WEIGHTS:.*]] = constant dense<1.000000e+00> : tensor<32x4x4x128xf32> + // CHECK: %[[BIAS:.*]] = constant dense<[-5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01, -5.000000e-01, 5.000000e-01]> : tensor<32xf32> + // CHECK: %[[RESULT:.*]] = "tfl.transpose_conv"(%[[SHAPE]], %[[WEIGHTS]], %arg0, %[[BIAS]]) + // CHECK: return %[[RESULT]] +} + +// CHECK-LABEL: fuseAddIntoTransposeConvNoBias +func @fuseAddIntoTransposeConvNoBias(%arg0: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { + %cst = constant dense<1.5> : tensor<32xf32> + %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %cst_1 = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + %cst_2 = constant dense<1.0> : tensor<32x4x4x128xf32> + %cst_3 = constant unit + %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> + %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + return %1 : 
tensor<1x64x84x32xf32> + + // CHECK: %[[SHAPE:.*]] = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + // CHECK: %[[WEIGHTS:.*]] = constant dense<1.000000e+00> : tensor<32x4x4x128xf32> + // CHECK: %[[BIAS:.*]] = constant dense<1.500000e+00> : tensor<32xf32> + // CHECK: %[[RESULT:.*]] = "tfl.transpose_conv"(%[[SHAPE]], %[[WEIGHTS]], %arg0, %[[BIAS]]) + // CHECK: return %[[RESULT]] +} + +// CHECK-LABEL: fuseMulIntoTransposeConv +func @fuseMulIntoTransposeConv(%arg0: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { + %cst = constant dense<1.5> : tensor<32xf32> + %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %cst_1 = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + %cst_2 = constant dense<1.0> : tensor<32x4x4x128xf32> + %cst_3 = constant dense<[1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]> : tensor<32xf32> + %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + %1 = "tfl.mul"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + return %1 : tensor<1x64x84x32xf32> + + // CHECK: %[[SHAPE:.*]] = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + // CHECK: %[[WEIGHTS:.*]] = constant dense<1.500000e+00> : tensor<32x4x4x128xf32> + // CHECK: %[[BIAS:.*]] = constant dense<[1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00, 1.500000e+00, 3.000000e+00]> : tensor<32xf32> + // CHECK: %[[RESULT:.*]] = "tfl.transpose_conv"(%[[SHAPE]], %[[WEIGHTS]], %arg0, %[[BIAS]]) + // CHECK: return %[[RESULT]] +} + +// CHECK-LABEL: fuseMulIntoTransposeConvNoBias +func @fuseMulIntoTransposeConvNoBias(%arg0: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { + %cst = constant dense<1.5> : tensor<32xf32> + %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %cst_1 = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + %cst_2 = constant dense<1.0> : tensor<32x4x4x128xf32> + %cst_3 = constant unit + %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> + %1 = "tfl.mul"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + return %1 : tensor<1x64x84x32xf32> + + // CHECK: %[[SHAPE:.*]] = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + // CHECK: %[[WEIGHTS:.*]] = constant dense<1.500000e+00> : tensor<32x4x4x128xf32> + // CHECK: %[[BIAS:.*]] = constant unit + // CHECK: %[[RESULT:.*]] = "tfl.transpose_conv"(%[[SHAPE]], %[[WEIGHTS]], %arg0, %[[BIAS]]) + // CHECK: return %[[RESULT]] +} + // CHECK-LABEL: fuseAddIntoFollowingConv2d func @fuseAddIntoFollowingConv2d(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32> { %cst = 
constant dense<1.5> : tensor @@ -1066,3 +1156,138 @@ func @DontConvertSqueezeToReshape(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: return %[[RESULT]] } +func @ConvertPow1ToIdentity(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<1.000000e+00> : tensor + %0 = "tfl.pow"(%arg0, %cst) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + +// CHECK-LABEL: ConvertPow1ToIdentity +// CHECK: return %arg0 +} + +func @ConvertPow2ToSquare(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<2.000000e+00> : tensor + %0 = "tfl.pow"(%arg0, %cst) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + +// CHECK-LABEL: ConvertPow2ToSquare +// CHECK: %[[RESULT:.*]] = tfl.mul %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> +// CHECK: return %[[RESULT]] +} + +func @ConvertIdentityGatherNdOp(%arg0: tensor<4x3xf32>) -> tensor<4x3xf32> { + %cst = constant dense<[[0], [1], [2], [3]]> : tensor<4x1xi32> + %0 = "tfl.gather_nd"(%arg0, %cst) : (tensor<4x3xf32>, tensor<4x1xi32>) -> tensor<4x3xf32> + return %0 : tensor<4x3xf32> + +// CHECK-LABEL: ConvertIdentityGatherNdOp +// CHECK-SAME: (%[[ARG:.*]]: tensor<4x3xf32>) -> tensor<4x3xf32> +// CHECK-NEXT: return %[[ARG]] : tensor<4x3xf32> +} + +func @ConvertIdentityGatherNdOp3D(%arg0: tensor<4x3x4xf32>) -> tensor<4x3x4xf32> { + %cst = constant dense<[[0], [1], [2], [3]]> : tensor<4x1xi32> + %0 = "tfl.gather_nd"(%arg0, %cst) : (tensor<4x3x4xf32>, tensor<4x1xi32>) -> tensor<4x3x4xf32> + return %0 : tensor<4x3x4xf32> + +// CHECK-LABEL: ConvertIdentityGatherNdOp3D +// CHECK-SAME: (%[[ARG:.*]]: tensor<4x3x4xf32>) -> tensor<4x3x4xf32> +// CHECK-NEXT: return %[[ARG]] : tensor<4x3x4xf32> +} + +func @ConvertIdentityScatterNd(%arg0: tensor<4x3xf32>) -> tensor<4x3xf32> { + %cst = constant dense<[[0], [1], [2], [3]]> : tensor<4x1xi32> + %shape = constant dense<[4, 3]> : tensor<2xi32> + %0 = "tfl.scatter_nd"(%cst, %arg0, %shape) : (tensor<4x1xi32>, tensor<4x3xf32>, tensor<2xi32>) -> tensor<4x3xf32> + return %0 : tensor<4x3xf32> + +// CHECK-LABEL: ConvertIdentityScatterNd +// CHECK-SAME: (%[[ARG:.*]]: tensor<4x3xf32>) -> tensor<4x3xf32> +// CHECK-NEXT: return %[[ARG]] : tensor<4x3xf32> +} + +func @ReshapeAddUnknownShape(%arg0: tensor<*xf32>) -> tensor<3x4xf32> { + %cst = constant dense<[3, 4]> : tensor<2xi32> + %cst_0 = constant dense<1.000000e+00> : tensor<3x4xf32> + %0 = "tfl.reshape"(%arg0, %cst) : (tensor<*xf32>, tensor<2xi32>) -> tensor<3x4xf32> + %1 = "tfl.add"(%0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> + return %1 : tensor<3x4xf32> +// CHECK-LABEL: ReshapeAddUnknownShape +// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0 +// CHECK: %[[rs2:.*]] = tfl.add %[[rs1]] +// CHECK: return %[[rs2]] +} + +func @FoldSumKeepDim(%arg0: tensor<8x128xf32>) -> tensor<8x1xf32> { + %cst = constant dense<1> : tensor<1xi32> + %cst_1 = constant dense<[8, 1]> : tensor<2xi32> + %0 = "tfl.sum"(%arg0, %cst) {keep_dims = false} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8xf32> + %1 = "tfl.reshape"(%0, %cst_1) : (tensor<8xf32>, tensor<2xi32>) -> tensor<8x1xf32> + return %1 : tensor<8x1xf32> + +// CHECK-LABEL: FoldSumKeepDim +// CHECK: %[[RESULT:.*]] = "tfl.sum"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32> +// CHECK: return %[[RESULT]] : tensor<8x1xf32> +} + +func @FoldReduceMinKeepDim(%arg0: tensor<8x128xf32>) -> tensor<1x128xf32> { + %cst = constant dense<0> : tensor<1xi32> + %cst_1 = constant 
dense<[1, 128]> : tensor<2xi32> + %0 = "tfl.reduce_min"(%arg0, %cst) {keep_dims = false} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<128xf32> + %1 = "tfl.reshape"(%0, %cst_1) : (tensor<128xf32>, tensor<2xi32>) -> tensor<1x128xf32> + return %1 : tensor<1x128xf32> + +// CHECK-LABEL: FoldReduceMinKeepDim +// CHECK: %[[RESULT:.*]] = "tfl.reduce_min"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<1x128xf32> +// CHECK: return %[[RESULT]] : tensor<1x128xf32> +} + +func @FoldReduceMaxKeepDim(%arg0: tensor<8x128xf32>) -> tensor<1x128xf32> { + %cst = constant dense<0> : tensor<1xi32> + %cst_1 = constant dense<[1, 128]> : tensor<2xi32> + %0 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = false} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<128xf32> + %1 = "tfl.reshape"(%0, %cst_1) : (tensor<128xf32>, tensor<2xi32>) -> tensor<1x128xf32> + return %1 : tensor<1x128xf32> + +// CHECK-LABEL: FoldReduceMaxKeepDim +// CHECK: %[[RESULT:.*]] = "tfl.reduce_max"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<1x128xf32> +// CHECK: return %[[RESULT]] : tensor<1x128xf32> +} + +func @FoldReduceProdKeepDim(%arg0: tensor<8x128xf32>) -> tensor<1x1xf32> { + %cst = constant dense<[0, 1]> : tensor<2xi32> + %cst_1 = constant dense<[1, 1]> : tensor<2xi32> + %0 = "tfl.reduce_prod"(%arg0, %cst) {keep_dims = false} : (tensor<8x128xf32>, tensor<2xi32>) -> tensor + %1 = "tfl.reshape"(%0, %cst_1) : (tensor, tensor<2xi32>) -> tensor<1x1xf32> + return %1 : tensor<1x1xf32> + +// CHECK-LABEL: FoldReduceProdKeepDim +// CHECK: %[[RESULT:.*]] = "tfl.reduce_prod"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<2xi32>) -> tensor<1x1xf32> +// CHECK: return %[[RESULT]] : tensor<1x1xf32> +} + +func @SoftMaxWithNormalization(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> { + %cst = constant dense<1> : tensor<1xi32> + %0 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32> + %1 = "tfl.sub"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32> + %2 = "tfl.exp"(%1) : (tensor<8x128xf32>) -> tensor<8x128xf32> + %3 = "tfl.sum"(%2, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32> + %4 = "tfl.div"(%2, %3) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32> + return %4 : tensor<8x128xf32> + +// CHECK-LABEL: SoftMaxWithNormalization +// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x128xf32>) -> tensor<8x128xf32> +// CHECK: return %[[RESULT]] : tensor<8x128xf32> +} + +func @SoftMaxWithoutNormalization(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> { + %cst = constant dense<1> : tensor<1xi32> + %0 = "tfl.exp"(%arg0) : (tensor<8x128xf32>) -> tensor<8x128xf32> + %1 = "tfl.sum"(%0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32> + %2 = "tfl.div"(%0, %1) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32> + return %2 : tensor<8x128xf32> + +// CHECK-LABEL: SoftMaxWithoutNormalization +// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x128xf32>) -> tensor<8x128xf32> +// CHECK: return %[[RESULT]] : tensor<8x128xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir index 6847cdd5874..2b871769c81 100644 --- 
a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir @@ -457,6 +457,7 @@ func @inference_standard_lstm_time_major_cannot_fuse(%arg0: tensor, % // ----- module { +// expected-warning @+1 {{we cannot fuse this lstm func because the batch size is not fixed, please consider setting fixed batch size}} func @dynamic_shape_non_fuse_standard_lstm(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor, tensor<8x40xf32>) -> tensor %1 = "tf.Add"(%0, %arg5) : (tensor, tensor<40xf32>) -> tensor @@ -519,3 +520,42 @@ func @func_with_call(%arg0: tensor<100xf32>) -> tensor<100xf32> { return %0 : tensor<100xf32> } } + +// ----- + +module { +func @tflite_custom_nms(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor, tensor, tensor, tensor) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} { + %0 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + return %0, %1, %2, %3 : tensor, tensor, tensor, tensor +} + +// CHECK-LABEL: func @tflite_custom_nms( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x100x4xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x100x91xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<100x4xf32>) -> (tensor, tensor, tensor, tensor) attributes {tf._implements = "TFLite_Detection_PostProcess", tf._reference = "mlir"} { +// CHECK: %[[VAL_3:.*]]:4 = "tfl.custom"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {custom_code = "TFLite_Detection_PostProcess", custom_option = opaque<"tfl", "0x6D61785F646574656374696F6E73006D61785F636C61737365735F7065725F646574656374696F6E006E756D5F636C6173736573006E6D735F73636F72655F7468726573686F6C64006E6D735F696F755F7468726573686F6C6400795F7363616C6500785F7363616C6500685F7363616C6500775F7363616C65007573655F726567756C61725F6E6D73000A217E8E465B681720313A00000C000000010000000A0000000000803F010000000A0000009A99193F0000003F5B0000000000000000000040000020410000A0400E06060E0E06060E0E0E322601"> : tensor<217xi8>} : (tensor<1x100x4xf32>, tensor<1x100x91xf32>, tensor<100x4xf32>) -> (tensor, tensor, tensor, tensor) +// CHECK: return %[[VAL_3]]#0, %[[VAL_3]]#1, %[[VAL_3]]#2, %[[VAL_3]]#3 : tensor, tensor, tensor, tensor +// CHECK: } +} + +// ----- + +module { +// expected-error @+1 {{Invalid number of results from TFLite_Detection_PostProcess}} +func @tflite_custom_nms_invalid_results(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: 
tensor<100x4xf32>) -> (tensor, tensor, tensor) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} + +// expected-error @+1 {{Invalid number of arguments to TFLite_Detection_PostProcess}} +func @tflite_custom_nms_invalid_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>) -> (tensor, tensor, tensor, tensor) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} + +// expected-error @+1 {{max_classes_per_detection attribute is not set or not an integer}} +func @tflite_custom_nms_missing_func_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor, tensor, tensor, tensor) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} { + %0 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + return %0, %1, %2, %3 : tensor, tensor, tensor, tensor +} +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 066139e179b..a0cc6cc1fdb 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -1,4 +1,5 @@ // RUN: tf-opt -tfl-prepare-tf %s | FileCheck %s +// RUN: tf-opt %s -tf-layout-optimization=force-data-format=NHWC -tfl-prepare-tf | FileCheck --check-prefix=LAYOUT --dump-input=always %s module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { @@ -53,6 +54,15 @@ func @depthwiseConv2D(tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>, tensor<256x3 // CHECK: %5 = "tf.DepthwiseConv2dNative" } +func @Conv2dNCHW(%arg0: tensor<256x3x32x32xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x16x30x30xf32> { + %0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x3x32x32xf32>, tensor<3x3x3x16xf32>) -> tensor<256x16x30x30xf32> + return %0 : tensor<256x16x30x30xf32> + + // LAYOUT-LABEL: Conv2dNCHW + // LAYOUT: "tfl.conv_2d" +} + + func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) { ^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>): // OK @@ -82,8 +92,8 @@ func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8 // offset - mean * scale * rsqrt(variance + epsilon) // CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]]) -// CHECK: 
%[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNorm"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) -// CHECK: "tf.FusedBatchNorm"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) +// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) +// CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) } func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) { @@ -483,6 +493,20 @@ func @StridedSliceEllipsisMaskBefore(%arg0: tensor<21x15x7xf32>) -> tensor<21x15 // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 3 : i64, ellipsis_mask = 0 : i64, end_mask = 3 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<21x15x2xf32> } +// CHECK-LABEL: @StridedSliceEllipsisMaskBeforeWithBeginAndEndMask +func @StridedSliceEllipsisMaskBeforeWithBeginAndEndMask(%arg0: tensor<4x5x4xf32>) -> tensor<4x4x4xf32> { + %cst = constant dense<[0, 1, 0]> : tensor<3xi32> + %cst_0 = constant dense<0> : tensor<3xi32> + %cst_1 = constant dense<1> : tensor<3xi32> + %0 = "tf.StridedSlice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 6 : i64, ellipsis_mask = 1 : i64, end_mask = 4 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x5x4xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x4x4xf32> + return %0 : tensor<4x4x4xf32> + + // CHECK: %[[CST:.*]] = constant dense<[0, 1, 0]> : tensor<3xi32> + // CHECK: %[[CST_0:.*]] = constant dense<0> : tensor<3xi32> + // CHECK: %[[CST_1:.*]] = constant dense<1> : tensor<3xi32> + // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST_0]], %[[CST_1]]) {begin_mask = 7 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x5x4xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x4x4xf32> +} + // CHECK-LABEL: @StridedSliceEllipsisMaskAfter func @StridedSliceEllipsisMaskAfter(%arg0: tensor<21x15x7xf32>) -> tensor<5x15x7xf32> { %cst = constant dense<0> : tensor<2xi32> @@ -595,4 +619,51 @@ func @xla_conv(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> { // CHECK: return %[[RES]] } +func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { + %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> + return %0: tensor<3x3xf32> + +// CHECK-LABEL: broadcast_to_f32 +// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32> +// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> +// CHECK: return [[MUL]] : tensor<3x3xf32> +} + +func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> { + %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> + return %0: tensor<3x3xi32> + +// CHECK-LABEL: broadcast_to_i32 +// CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32> +// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> +// CHECK: return [[MUL]] : tensor<3x3xi32> +} + +// CHECK-LABEL: lower_rfft_to_rfft2d +func @lower_rfft_to_rfft2d(%input: tensor<10x20x30xf32>, %fft_len: tensor<1xi32>) -> tensor<10x20x30xcomplex> { + %0 = "tf.RFFT"(%input, %fft_len) : (tensor<10x20x30xf32>, tensor<1xi32>) -> 
tensor<10x20x30xcomplex> + return %0: tensor<10x20x30xcomplex> + +// CHECK: %[[CST:.*]] = constant dense<-2> : tensor +// CHECK: %[[CST0:.*]] = constant dense<1> : tensor<1xi32> +// CHECK: %[[CST1:.*]] = constant dense<0> : tensor +// CHECK: %[[EXP:.*]] = "tf.ExpandDims"(%arg0, %[[CST]]) : (tensor<10x20x30xf32>, tensor) -> tensor<10x20x1x30xf32> +// CHECK: %[[CON:.*]] = "tf.ConcatV2"(%[[CST0]], %arg1, %[[CST1]]) : (tensor<1xi32>, tensor<1xi32>, tensor) -> tensor<2xi32> +// CHECK: %[[RFF:.*]] = "tf.RFFT2D"(%[[EXP]], %[[CON]]) : (tensor<10x20x1x30xf32>, tensor<2xi32>) -> tensor<10x20x1x30xcomplex> +// CHECK: %[[SQE:.*]] = "tf.Squeeze"(%[[RFF]]) {squeeze_dims = [-2]} : (tensor<10x20x1x30xcomplex>) -> tensor<10x20x30xcomplex> +} + +// CHECK-LABEL: xla_gather_to_slice +func @xla_gather_to_slice(%arg0 : tensor<1x9x104x768xf32>) -> tensor<*xf32> { + %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<[1, 9, 23, 768]> : tensor<4xi32>} : () -> tensor<4xi32> + %2 = "tf.XlaGather"(%arg0, %0, %1) {device = "", dimension_numbers = "\0A\04\00\01\02\03\1A\01\02", indices_are_sorted = false} : (tensor<1x9x104x768xf32>, tensor<1xi32>, tensor<4xi32>) -> tensor<*xf32> + return %2 : tensor<*xf32> + +// CHECK: %[[CST:.*]] = constant dense<0> : tensor<4xi64> +// CHECK: %[[CST0:.*]] = constant dense<[1, 9, 23, 768]> : tensor<4xi64> +// CHECK: %[[V0:.*]] = "tf.Slice"(%arg0, %[[CST]], %[[CST0]]) : (tensor<1x9x104x768xf32>, tensor<4xi64>, tensor<4xi64>) -> tensor<*xf32> +// CHECK: return %[[V0]] : tensor<*xf32> +} + } diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index d63eb481376..2feb7fedb81 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -38,6 +38,10 @@ CreateTFExecutorToControlDialectConversion(); } // namespace mlir namespace tensorflow { +namespace { +// Data layout supported by TFLite. +const char kTFLiteDataLayout[] = "NHWC"; +} // namespace void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs, mlir::OpPassManager* pass_manager) { @@ -170,6 +174,12 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, if (pass_config.shape_inference) { pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); } + // Force layout supported by TFLite, this will transpose the data + // to match 'kTFLiteDataLayout' + mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options; + layout_optimization_options.force_data_format = kTFLiteDataLayout; + mlir::TF::CreateLayoutOptimizationPipeline(*pass_manager, + layout_optimization_options); // Prepare for TFLite dialect, rerun canonicalization, and then legalize to // the TFLite dialect. pass_manager->addPass( diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 414a0de0118..c158f3a8e21 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -129,6 +129,18 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( bool emit_select_tf_ops, bool emit_custom_ops, const mlir::TFL::QuantizationSpecs& quant_specs, std::string* result, mlir::PassManager* pass_manager) { + // Register a warning handler only log to std out. 
+ mlir::ScopedDiagnosticHandler s( + module.getContext(), [](mlir::Diagnostic& diag) { + if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning) { + for (auto& note : diag.getNotes()) { + std::cout << note.str() << "\n"; + LOG(WARNING) << note.str() << "\n"; + } + } + return mlir::failure(); + }); + mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(), /*propagate=*/true); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 47cfaecd3fb..322da815a47 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -27,6 +27,9 @@ def NonOpaqueElementsAttr : ElementsAttrBase< def F32ElementsAttr : ElementsAttrBase< CPred<"$_self.cast().getType().getElementType().isF32()">, "float constant tensor">; +def Int64ElementsAttr : ElementsAttrBase< + CPred<"$_self.cast().getType().getElementType().isInteger(64)">, "Int 64 constant tensor">; + // Extract the ith int element from an ArrayAttr $0 as an 32-bit IntegerAttr // with builder. class ExtractI32At : NativeCodeCall< @@ -50,6 +53,10 @@ def ExtractSingleElementAsInteger : NativeCodeCall< def ExtractSingleElementAsInt32 : NativeCodeCall< "$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast()).getInt())">; +// Converts tensor with int64 to int32. +def CreateCastToInt32 : NativeCodeCall< + "CreateCastToInt32($0, $_loc, $_builder)">; + // Checks whether the given operation has static shapes and same shapes of all inputs. def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">; def HasSameStaticShapes : Constraint; @@ -149,6 +156,7 @@ def LegalizeMaxPool2D : Pat< IsIntList1XY1:$ksize, IsIntList1XY1:$strides, $padding, + $explicit_paddings, IsDataFormatNHWC:$format), (TFL_MaxPool2DOp $value, /*padding=*/$padding, @@ -207,8 +215,14 @@ def LegalizeSoftPlus : Pat<(TF_SoftplusOp F32Tensor:$arg0), def LegalizeSqueeze : Pat<(TF_SqueezeOp $arg, $squeeze_dims), (TFL_SqueezeOp $arg, $squeeze_dims)>; def LegalizeTanh : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>; + +def LegalizeTransposeInt64 : Pat< + (TF_TransposeOp $arg, (ConstantOp Int64ElementsAttr:$perm)), + (TFL_TransposeOp $arg, (CreateCastToInt32 $perm))>; + def LegalizeTranspose : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>; + def LegalizeWhere : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>; def LegalizeZerosLike : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 7a16e475ce3..6f7f3b88471 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -30,6 +30,7 @@ limitations under the License. #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Threading.h" #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project @@ -45,8 +46,10 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" @@ -64,7 +67,6 @@ namespace TFL { // The actual LegalizeTF Pass. namespace { -using xla::Status; using xla::StatusOr; constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm"; @@ -73,6 +75,10 @@ constexpr char kTfLiteInputIndices[] = "_tflite_input_indices"; // Legalize operations in functions. class LegalizeTF : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: LegalizeTF() = default; LegalizeTF(const LegalizeTF&) {} @@ -111,6 +117,17 @@ bool HasSameStaticShapes(Operation* op) { return true; } +// Util that casts 'val' to Int32 by adding a cast Op. +Value CreateCastToInt32(Attribute val, Location loc, + PatternRewriter& rewriter) { + auto shape = val.getType().dyn_cast().getShape(); + IntegerType new_ele_type = rewriter.getIntegerType(32); + ShapedType new_type = RankedTensorType::get(shape, new_ele_type); + return rewriter.create(loc, new_type, + rewriter.create(loc, val), + rewriter.getBoolAttr(false)); +} + #include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc" #define DECL_CONVERT_OP(tf_op) \ @@ -137,7 +154,6 @@ DECL_CONVERT_OP(StridedSlice); DECL_CONVERT_OP(Unpack); DECL_CONVERT_OP(Reciprocal); DECL_CONVERT_OP(RandomUniform); -DECL_CONVERT_OP(BroadcastTo); #undef DECL_CONVERT_OP @@ -154,9 +170,8 @@ LogicalResult ConvertTFRandomUniformOp::matchAndRewrite( tensorflow::random::PhiloxRandom, float> Distribution; - tensorflow::random::PhiloxRandom generator( - random_uniform_op.seed().getSExtValue(), - random_uniform_op.seed2().getSExtValue()); + tensorflow::random::PhiloxRandom generator(random_uniform_op.seed(), + random_uniform_op.seed2()); Distribution dist; size_t num_elements = 0; if (auto output_type = @@ -227,26 +242,47 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite( return success(); } -// The following is effectively: -// def : Pat< -// (TF_MatMulOp $a, $b, ConstBoolAttrFalse:$transpose_a, -// ConstBoolAttrTrue:$transpose_b), -// (TFL_FullyConnectedOp:$__0 $a, $b, -// NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>; LogicalResult ConvertTFMatMulOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_matmul_op = cast(op); - if (tf_matmul_op.transpose_a()) return failure(); - if (!tf_matmul_op.transpose_b()) return failure(); + auto lhs = op->getOperand(0); + auto rhs = op->getOperand(1); + auto transpose = [&](Value input) -> std::pair { + RankedTensorType type = + input.getType().dyn_cast_or_null(); + if (!type || type.getRank() != 2) return {failure(), nullptr}; + + auto permute_attr = DenseIntElementsAttr::get( + RankedTensorType::get({2}, rewriter.getI32Type()), {1, 0}); + auto permute = rewriter.create( + op->getLoc(), permute_attr.getType(), permute_attr); + llvm::SmallVector new_shape{type.getShape()[1], + type.getShape()[0]}; + auto output = rewriter.create( + op->getLoc(), 
RankedTensorType::get(new_shape, type.getElementType()), + input, permute); + return {success(), output}; + }; + + // TODO(jpienaar): Remove once handled via dailect conversion. + if (tf_matmul_op.transpose_a()) { + LogicalResult result = success(); + std::tie(result, lhs) = transpose(lhs); + if (failed(result)) return failure(); + } + if (!tf_matmul_op.transpose_b()) { + LogicalResult result = success(); + std::tie(result, rhs) = transpose(rhs); + if (failed(result)) return failure(); + } Type output_type = tf_matmul_op.getResult().getType(); - // TODO(jpienaar): Follow up post shuffle discussion. auto no_input = rewriter.create( op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); auto fc_op = rewriter.create( - op->getLoc(), ArrayRef{output_type}, op->getOperand(0), - op->getOperand(1), no_input, rewriter.getStringAttr("NONE"), - rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false)); + op->getLoc(), ArrayRef{output_type}, lhs, rhs, no_input, + rewriter.getStringAttr("NONE"), rewriter.getStringAttr("DEFAULT"), + rewriter.getBoolAttr(false)); rewriter.replaceOp(op, {fc_op.getResult(0)}); return success(); } @@ -259,7 +295,7 @@ LogicalResult ConvertTFPackOp::matchAndRewrite( auto output_type = tf_pack_op.output().getType(); auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N()); // Axis can be negative. - auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis().getSExtValue()); + auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis()); rewriter.replaceOpWithNewOp(op, output_type, values, values_count, axis); @@ -356,27 +392,22 @@ LogicalResult ConvertTFStridedSliceOp::matchAndRewrite( op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(), tf_strided_slice_op.begin(), tf_strided_slice_op.end(), tf_strided_slice_op.strides(), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.begin_mask().getSExtValue()), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.end_mask().getSExtValue()), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.ellipsis_mask().getSExtValue()), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.new_axis_mask().getSExtValue()), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.shrink_axis_mask().getSExtValue())); + rewriter.getI32IntegerAttr(tf_strided_slice_op.begin_mask()), + rewriter.getI32IntegerAttr(tf_strided_slice_op.end_mask()), + rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()), + rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()), + rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask())); return success(); } int num_input_dims = ranked_input_type.getRank(); // Pad `begin` array with zero values and update the `begin_mask`. SmallVector begin_pad_val(num_input_dims, 0); - int begin_mask = tf_strided_slice_op.begin_mask().getSExtValue(); + int begin_mask = tf_strided_slice_op.begin_mask(); Value padded_begin = PadStridedSliceAttributeArray( op, rewriter, tf_strided_slice_op.begin(), begin_pad_val, &begin_mask); // Pad `end` array with `input_shape` and update the `end_mask`. 
- int end_mask = tf_strided_slice_op.end_mask().getSExtValue(); + int end_mask = tf_strided_slice_op.end_mask(); auto input_shape = ranked_input_type.getShape(); SmallVector end_pad_val(input_shape.begin(), input_shape.end()); Value padded_end = PadStridedSliceAttributeArray( @@ -390,12 +421,9 @@ LogicalResult ConvertTFStridedSliceOp::matchAndRewrite( padded_begin, padded_end, padded_strides, rewriter.getI32IntegerAttr(begin_mask), rewriter.getI32IntegerAttr(end_mask), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.ellipsis_mask().getSExtValue()), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.new_axis_mask().getSExtValue()), - rewriter.getI32IntegerAttr( - tf_strided_slice_op.shrink_axis_mask().getSExtValue())); + rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()), + rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()), + rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask())); return success(); } @@ -406,7 +434,7 @@ LogicalResult ConvertTFUnpackOp::matchAndRewrite( auto input = tf_unpack_op.value(); auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num()); // Axis can be negative. - auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis().getSExtValue()); + auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis()); rewriter.replaceOpWithNewOp(op, tf_unpack_op.output().getTypes(), input, num, axis); @@ -483,89 +511,6 @@ LogicalResult ConvertTFAssertOp::matchAndRewrite( return success(); } -StatusOr CreateConstOpWithSingleValue(PatternRewriter* rewriter, - Location loc, - ShapedType shaped_type, - int value) { - Type element_type = shaped_type.getElementType(); - ShapedType scalar_type = RankedTensorType::get({}, element_type); - Attribute attr; - switch (element_type.getKind()) { - case mlir::StandardTypes::F16: { - auto floatType = mlir::FloatType::getF16(element_type.getContext()); - auto floatAttr = - mlir::FloatAttr::get(floatType, static_cast(value)); - std::vector floatValues({floatAttr}); - attr = DenseElementsAttr::get(scalar_type, floatValues); - break; - } - case mlir::StandardTypes::BF16: { - auto floatType = mlir::FloatType::getBF16(element_type.getContext()); - auto floatAttr = - mlir::FloatAttr::get(floatType, static_cast(value)); - std::vector floatValues({floatAttr}); - attr = DenseElementsAttr::get(scalar_type, floatValues); - break; - } - case mlir::StandardTypes::F32: { - attr = - DenseElementsAttr::get(scalar_type, static_cast(value)); - break; - } - case mlir::StandardTypes::Complex: { - auto etype = element_type.cast().getElementType(); - if (etype.isF32()) { - auto dialect = etype.getContext()->getRegisteredDialect("tf"); - tensorflow::TensorProto repr; - repr.set_dtype(tensorflow::DT_COMPLEX64); - - tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape(); - shape->set_unknown_rank(false); - shape->add_dim()->set_size(int64_t{1}); - std::string content; - auto complex_value = - std::complex(static_cast(value), 0.0f); - content.assign(reinterpret_cast(&complex_value), - sizeof(complex_value)); - repr.set_tensor_content(content); - std::string mangled = tensorflow::mangling_util::MangleTensor(repr); - - attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled); - break; - } - return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type"); - } - case mlir::StandardTypes::Integer: { - const auto& itype = element_type.cast(); - switch (itype.getWidth()) { - case 8: - attr = DenseElementsAttr::get(scalar_type, - static_cast(value)); - break; - case 16: - attr = 
DenseElementsAttr::get(scalar_type, - static_cast(value)); - break; - case 32: - attr = DenseElementsAttr::get(scalar_type, - static_cast(value)); - break; - case 64: - attr = DenseElementsAttr::get(scalar_type, - static_cast(value)); - break; - default: - return Status(tensorflow::error::INVALID_ARGUMENT, - "Unsupported type"); - } - break; - } - default: - return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type"); - } - return rewriter->create(loc, scalar_type, attr); -} - LogicalResult ConvertTFReciprocalOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_reciprocal_op = cast(op); @@ -586,31 +531,6 @@ LogicalResult ConvertTFReciprocalOp::matchAndRewrite( return success(); } -LogicalResult ConvertTFBroadcastToOp::matchAndRewrite( - Operation* op, PatternRewriter& rewriter) const { - auto tf_broadcast_to_op = cast(op); - auto element_type = tf_broadcast_to_op.input().getType().cast(); - auto output_type = tf_broadcast_to_op.output().getType(); - - auto status_or_const_op = - CreateConstOpWithSingleValue(&rewriter, op->getLoc(), element_type, 1); - if (!status_or_const_op.ok()) { - return failure(); - } - - auto tfl_fill_op = rewriter.create( - op->getLoc(), output_type, tf_broadcast_to_op.shape(), - status_or_const_op.ValueOrDie()); - - StringAttr fused_activation_function = - StringAttr::get("NONE", rewriter.getContext()); - - rewriter.replaceOpWithNewOp( - op, output_type, tf_broadcast_to_op.input(), tfl_fill_op, - fused_activation_function); - return success(); -} - // Legalize unidirectional sequence lstm. struct LegalizeUnidirectionalSequenceLstm : public RewritePattern { explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context) @@ -751,7 +671,7 @@ void LegalizeTF::runOnFunction() { ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp, - ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context); + ConvertTFRandomUniformOp>(context); // Ophint python converter converted tf node pattern. patterns.insert> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void RunOnFunction(FuncOp func); void runOnOperation() override { @@ -60,8 +64,8 @@ void RunOnWhile(TF::WhileOp while_op) { // Mark old function as private so that it can be DCE'd if not called. 
func.setVisibility(SymbolTable::Visibility::Private); }; - create_region_with_call(while_op.cond_func(), new_op.cond()); - create_region_with_call(while_op.body_func(), new_op.body()); + create_region_with_call(while_op.cond_function(), new_op.cond()); + create_region_with_call(while_op.body_function(), new_op.body()); op->replaceAllUsesWith(new_op.getResults()); op->erase(); diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index edddc7751ab..54bfc5fa3a7 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -714,7 +714,7 @@ struct ConvertTensorListStack RankedTensorType shape_type = RankedTensorType::get({-1}, rewriter.getIntegerType(32)); auto new_shape = rewriter.create(loc, shape_type, input); - SmallVector output_shape = {op.num_elements().getSExtValue()}; + SmallVector output_shape(/*Size=*/1, op.num_elements()); for (const auto &dim : dense_elem_attr.getIntValues()) output_shape.push_back(dim.getSExtValue()); RankedTensorType result_type = @@ -749,7 +749,7 @@ Type VariantToUnrankedTensorType(Type type, Value value) { // Changes the function type of `cond_func` and `body_func` for the given While // op. LogicalResult UpdateFunctionTypes(TF::WhileOp op) { - for (FuncOp func : {op.cond_func(), op.body_func()}) { + for (FuncOp func : {op.cond_function(), op.body_function()}) { if (!func) continue; FunctionType func_type = func.getType(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 6de6187d81a..d28ee4b31fa 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -27,6 +27,7 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" @@ -37,8 +38,10 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -103,7 +106,8 @@ bool OperandsBroadcastToOutputType(Type a, Type b, Type expected_output) { bool IsTailOfShape(Type type1, Type type2) { auto tail_type = type1.dyn_cast(); auto full_type = type2.dyn_cast(); - if (!tail_type || !full_type || tail_type.getRank() > full_type.getRank()) + if (!tail_type || !full_type || !tail_type.hasRank() || + !full_type.hasRank() || tail_type.getRank() > full_type.getRank()) return false; auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend(); auto i2 = full_type.getShape().rbegin(); @@ -160,6 +164,31 @@ bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val, return false; } +// Retuns true if we can eliminate the GatherNdOp or ScatterNdOp. 
When the value +// of `indices` are from 0 to n-1, the output tensor are identical to the +// `params`. +bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params, + DenseIntElementsAttr indices) { + auto params_type = params.getType().dyn_cast(); + auto indices_type = indices.getType().dyn_cast(); + // Checks the shape of `params` is [n, ...], shape of `indices` is [n, 1]. 2D + // `indices` means it gets the first row of `params`. As long as indices + // iterate the first row of `params`, the output is identical to input. + if (!params_type || !indices_type || indices_type.getRank() != 2 || + indices_type.getDimSize(0) != params_type.getDimSize(0) || + indices_type.getDimSize(1) != 1) + return false; + + // Checks the value in `indices` is from 0 to n-1. + int cur_value = 0; + for (const auto &v : indices.getValues()) { + if (v.getSExtValue() != cur_value) return false; + ++cur_value; + } + + return true; +} + // Expand Attribute 'a' to 4D with all 1s except 1 dimension. // Which dimension depends on 'is_depthwise' is true or false. ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) { @@ -219,6 +248,38 @@ static Type GetShapeStrippedType(TypeAttr type_attr) { } } +// Returns `true` if reducing `axes` in `input` with `keep_dims=true` results in +// the specified `shape` and `false` otherwise. +static bool ShapeMatchesReduceWithKeepAxes(Value input, + const mlir::Attribute &axes, + const mlir::Attribute &shape) { + RankedTensorType type = input.getType().dyn_cast_or_null(); + if (!type) return false; + + DenseIntElementsAttr axes_attr = + axes.dyn_cast_or_null(); + DenseIntElementsAttr shape_attr = + shape.dyn_cast_or_null(); + if (!axes_attr || !shape_attr) return false; + + if (shape_attr.getNumElements() != type.getRank()) return false; + + llvm::SmallSet axes_set; + for (auto a : axes_attr.getIntValues()) { + axes_set.insert(a.getZExtValue()); + } + + auto type_shape = type.getShape(); + for (uint64_t i = 0; i < type.getRank(); ++i) { + if (axes_set.contains(i)) { + if (shape_attr.getValue({i}) != 1) return false; + } else { + if (shape_attr.getValue({i}) != type_shape[i]) return false; + } + } + return true; +} + #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc" // Fuse Add with proceeding FullyConnected. diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc index 2311ae0668c..f1ea837446b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc @@ -83,8 +83,8 @@ class FoldIfOp : public OpRewritePattern { if (!llvm::hasSingleElement(parent_op)) return failure(); // Find the then and else branch functions. - FuncOp then_func = op.then_func(); - FuncOp else_func = op.else_func(); + FuncOp then_func = op.then_function(); + FuncOp else_func = op.else_function(); // If the If has no uses and its functions are side-effect free, then // remove. 
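For context on the optimize.cc helpers above (ShapeMatchesReduceWithKeepAxes and CanOptimizeIdentityGatherNdOrScatterNdOp), here is a small self-contained sketch of the two equivalences they depend on: a reshape that only re-inserts the reduced size-1 axis can be folded by switching the reduction to keep_dims=true, and a gather_nd whose indices enumerate 0..n-1 is the identity. The flat-vector "tensors" and the ReduceSumLastAxis helper are illustrative assumptions only, not code from this patch or from TFLite.

```cpp
// Standalone sketch (not part of this patch) of the equivalences behind the
// new optimize.cc helpers; flat vectors stand in for tensors.
#include <cassert>
#include <cstdint>
#include <vector>

// Sum over the last axis of a row-major [rows x cols] matrix. The data is the
// same whether the result is reported as shape [rows] (keep_dims=false) or
// [rows, 1] (keep_dims=true); only the reported shape differs.
std::vector<float> ReduceSumLastAxis(const std::vector<float>& m, int rows,
                                     int cols) {
  std::vector<float> out(rows, 0.0f);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c) out[r] += m[r * cols + c];
  return out;
}

int main() {
  const int rows = 2, cols = 3;
  const std::vector<float> x = {1, 2, 3, 4, 5, 6};  // 2x3, row-major

  // 1) FoldReshapeTo<Reduce>: reshape(reduce(x, keep_dims=false), [2, 1])
  //    carries exactly the data of reduce(x, keep_dims=true), so the reshape
  //    can be folded by flipping keep_dims -- the target-shape condition that
  //    ShapeMatchesReduceWithKeepAxes verifies.
  std::vector<float> reduced = ReduceSumLastAxis(x, rows, cols);
  assert(reduced == (std::vector<float>{6, 15}));

  // 2) OptimizeIdentityGatherNdOp: gathering the rows of an [n, ...] params
  //    tensor with indices [[0], [1], ..., [n-1]] reproduces params exactly,
  //    so the op can be replaced by its input once
  //    CanOptimizeIdentityGatherNdOrScatterNdOp confirms that index pattern.
  const std::vector<int32_t> indices = {0, 1};  // shape [2, 1]
  std::vector<float> gathered;
  for (int32_t i : indices)
    for (int c = 0; c < cols; ++c) gathered.push_back(x[i * cols + c]);
  assert(gathered == x);
  return 0;
}
```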
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 83a09e9dd2b..8243ed2a620 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -21,8 +21,13 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "tensorflow/compiler/mlir/lite/utils/utils.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +// Checks if the param passed is a F32 ElementsAttr. def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getType().getElementType().isF32()">, "float constant tensor">; + CPred<"$_self.isa() && $_self.cast().getType().getElementType().isF32()">, + "float constant tensor">; + +// Checks if the param passed is of NoneType. +def IsNoneType : Constraint()">>; def ExtractSingleElementAsFloat : NativeCodeCall< "ExtractSingleElementAsFloat($_self.cast())">; @@ -93,6 +98,29 @@ multiclass FuseBinaryOpToPrecedingAffine { $multiplier), [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value), (HasOneUse $output)]>; + def FuseBinaryOpWithTransposeConv#binaryOp : Pat< + (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs, + (ConstantOp F32ElementsAttr:$bias), $padding, + $stride_h, $stride_w), + (ConstantOp F32ElementsAttr:$value), TFL_AF_None), + (TFL_TransposeConvOp $output_shape, $weights, $inputs, + (binaryOp (ConstantOp $bias), + (ConstantOp $value), TFL_AF_None), + $padding, $stride_h, $stride_w), + [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value), + (HasOneUse $output)]>; + // Fuse for TransposeConv with no bias + def FuseBinaryOpWithTransposeConvNoneBias#binaryOp : Pat< + (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs, + (ConstantOp $bias), $padding, + $stride_h, $stride_w), + (ConstantOp F32ElementsAttr:$value), TFL_AF_None), + (TFL_TransposeConvOp $output_shape, $weights, $inputs, + (ConstantOp $value), + $padding, $stride_h, $stride_w), + [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value), + (IsNoneType $bias), + (HasOneUse $output)]>; } foreach binaryOp = [TFL_AddOp, TFL_SubOp] in defm : FuseBinaryOpToPrecedingAffine; @@ -146,6 +174,39 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d { $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w), [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value), (HasOneUse $conv_output)]>; + def FuseMulOrDivWithTransposeConv#BinaryOp : Pat< + (BinaryOp (TFL_TransposeConvOp:$output $output_shape, + (ConstantOp F32ElementsAttr:$weights), $input, + (ConstantOp F32ElementsAttr:$bias), + $padding, $stride_h, $stride_w), + (ConstantOp $value), TFL_AF_None), + (TFL_TransposeConvOp $output_shape, + (BinaryOp (ConstantOp $weights), + (ConstantOp (ExpandTo4DForConv $value)), + TFL_AF_None), + $input, + (BinaryOp (ConstantOp $bias), + (ConstantOp $value), + TFL_AF_None), + $padding, $stride_h, $stride_w), + [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value), + (HasOneUse $output)]>; + def FuseMulOrDivWithTransposeConvWithNoneBias#BinaryOp : Pat< + (BinaryOp (TFL_TransposeConvOp:$output $output_shape, + (ConstantOp F32ElementsAttr:$weights), $input, + (ConstantOp $bias), + $padding, $stride_h, $stride_w), + (ConstantOp $value), TFL_AF_None), + (TFL_TransposeConvOp $output_shape, + (BinaryOp (ConstantOp $weights), + (ConstantOp (ExpandTo4DForConv $value)), + TFL_AF_None), + $input, + (ConstantOp $bias), + $padding, $stride_h, $stride_w), + [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value), + 
(IsNoneType $bias), + (HasOneUse $output)]>; } foreach BinaryOp = [TFL_DivOp, TFL_MulOp] in @@ -508,3 +569,81 @@ foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in { def OptimizeReluSquaredDifference : Pat< (TFL_ReluOp (TFL_SquaredDifferenceOp $l, $r)), (TFL_SquaredDifferenceOp $l, $r)>; + +// Optimize X^1 o X +def OptimizePow1ToIdentity : Pat< + (TFL_PowOp $input, + (ConstantOp ConstantAttr, "1.0f">)), + (replaceWithValue $input)>; + +// Optimize X^2 to X*X +def OptimizePow2ToSquare : Pat< + (TFL_PowOp $input, + (ConstantOp ConstantAttr, "2.0f">)), + (TFL_MulOp $input, $input, TFL_AF_None)>; + +def CanOptimizeIdentityGatherNdOrScatterNdOp : Constraint())">>; + +def OptimizeIdentityGatherNdOp : Pat< + (TFL_GatherNdOp $params, (ConstantOp I32ElementsAttr: $indices)), + (replaceWithValue $params), + [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>; + +def OptimizeIdentityScatterNdOp : Pat< + (TFL_ScatterNdOp (ConstantOp I32ElementsAttr: $indices), $params, $ignored), + (replaceWithValue $params), + [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>; + +def ShapeMatchesReduceWithKeepAxes : Constraint>; + +// Fold reshapes re-inserting reduced dimensions into the results of a reduction +// with `keep_dims=false` by chaning it to one using `keep_dims=true`. +foreach ReduceOp = [TFL_ReduceMaxOp, TFL_ReduceMinOp, TFL_ReduceProdOp, + TFL_SumOp] in { + def FoldReshapeTo#ReduceOp : Pat< + (TFL_ReshapeOp + (ReduceOp:$reduce $input, (ConstantOp I32ElementsAttr: $axes), + ConstBoolAttrFalse), + (ConstantOp I32ElementsAttr: $shape)), + (ReduceOp $input, (ConstantOp $axes), ConstBoolAttrTrue), + [(ShapeMatchesReduceWithKeepAxes $input, $axes, $shape), + (HasOneUse $reduce)]>; +} + + +def IsSame : Constraint>; +def HasTwoUse : Constraint>; +def AxesIsLastDimension : Constraint().getNumElements() == 1 && " + "$0.cast().getValue({0}) == " + "$1.getType().cast().getRank() - 1">>; + +// Convert exp(x)/sum(exp(x)) into softmax. +def OptimizeToSoftmax : Pat< + (TFL_DivOp (TFL_ExpOp:$exp $input), + (TFL_SumOp:$sum $sum_input, (ConstantOp I32ElementsAttr: $axes), + ConstBoolAttrTrue), TFL_AF_None), + (TFL_SoftmaxOp $input, ConstF32Attr<"1.0">), + [(IsSame $exp, $sum_input), + (AxesIsLastDimension $axes, $sum_input), + (HasTwoUse $exp), + (HasOneUse $sum)]>; + +// Convert softmax(x-max(x)) into softmax(x) as the softmax op already deals +// with the max normalization. +def FoldNormalizationIntoSoftmax : Pat< + (TFL_SoftmaxOp + (TFL_SubOp:$sub $input, + (TFL_ReduceMaxOp:$max $max_input, (ConstantOp I32ElementsAttr: $axes), + ConstBoolAttrTrue), + TFL_AF_None), + $beta), + (TFL_SoftmaxOp $input, $beta), + [(IsSame $input, $max_input), + (AxesIsLastDimension $axes, $max_input), + (HasOneUse $sub), + (HasOneUse $max)]>; diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 3be6246c0dd..172ce59ddd4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -42,6 +42,7 @@ limitations under the License. 
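The OptimizeToSoftmax and FoldNormalizationIntoSoftmax patterns above hinge on the shift invariance of softmax; as a quick illustration (this derivation is only a reading aid, not part of the TableGen changes):

\[
\operatorname{softmax}(x)_i \;=\; \frac{e^{x_i}}{\sum_j e^{x_j}},
\qquad
\operatorname{softmax}(x - c\mathbf{1})_i
  \;=\; \frac{e^{x_i - c}}{\sum_j e^{x_j - c}}
  \;=\; \frac{e^{-c}\, e^{x_i}}{e^{-c}\sum_j e^{x_j}}
  \;=\; \operatorname{softmax}(x)_i
\quad \text{for any scalar } c .
\]

With c = max_j x_j this is exactly the max normalization that FoldNormalizationIntoSoftmax strips, and exp(x)/sum(exp(x), last_axis) is the definition OptimizeToSoftmax matches; the HasTwoUse/HasOneUse constraints ensure the exp and sum feed nothing but the rewritten div, so they can be dropped.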
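Likewise, the FoldReshapeTo#ReduceOp rewrite above only re-inserts the dimensions that keep_dims=true would have preserved, and the identity GatherNd/ScatterNd folds rely on indices enumerating the first axis in order. A small NumPy sanity check of both facts (illustrative only, not code from this change):

    import numpy as np

    x = np.random.randn(2, 3, 4).astype(np.float32)

    # Reshape(ReduceSum(x, axes, keep_dims=False), kept_shape)
    #   == ReduceSum(x, axes, keep_dims=True)
    kept = np.sum(x, axis=-1, keepdims=True)                      # shape (2, 3, 1)
    reshaped = np.sum(x, axis=-1, keepdims=False).reshape(2, 3, 1)
    assert np.array_equal(kept, reshaped)

    # GatherNd(params, [[0], [1], ..., [n-1]]) is the identity on the first axis.
    idx = np.arange(x.shape[0])
    assert np.array_equal(x[idx], x)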
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/nms_utils.h" #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -59,6 +60,7 @@ namespace { constexpr char kTFAPIImplements[] = "tf.api_implements"; constexpr char kTFTextAPIPrefix[] = "tftext:"; +constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess"; constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2"; using mlir::TF::FuncAttr; @@ -99,59 +101,6 @@ class ConvertEmbeddedLookupFunc { FuncOp func_; }; -// Abstracts the conversion of the padded NMS composite function. -class ConvertNMSPaddedFunc { - public: - explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {} - - void RewriteFunc() { - func_.setAttr(kTFImplements, - StringAttr::get(kTfNMSPadded, func_.getContext())); - Value boxes = func_.getArgument(0); - Value scores = func_.getArgument(1); - Value max_output_size = func_.getArgument(2); - Value iou_threshold = func_.getArgument(3); - Value score_threshold = func_.getArgument(4); - auto output_type0 = func_.getType().getResult(0); - auto output_type1 = func_.getType().getResult(1); - - OpBuilder builder(func_.getBody()); - auto op = builder.create( - func_.getLoc(), output_type0, output_type1, boxes, scores, - max_output_size, iou_threshold, score_threshold); - - builder.create(func_.getLoc(), op.getResults()); - } - - LogicalResult VerifySignature() { - // Verify high-level function signature. - // Relevant argument characteristics are checked by the TFL op definition. - if (func_.getNumArguments() < 5) { - return func_.emitError() - << "Invalid number of arguments to " - "non_max_suppression_padded_v2 (need atleast 5): " - << func_.getNumArguments(); - } - if (func_.getType().getNumResults() != 2) { - return func_.emitError() << "Invalid number of results from " - "non_max_suppression_padded_v2 (need 2): " - << func_.getType().getNumResults(); - } - // The TFLite fused op does not support batching yet. - // TODO(b/158709815): Add support for batches with padded NMS. 
- auto boxes_type = - func_.getArgument(0).getType().dyn_cast(); - if (!boxes_type.hasRank() || boxes_type.getRank() != 2) { - return func_.emitError() << "TFLite does not support batched input for " - "non_max_suppression_padded"; - } - return success(); - } - - private: - FuncOp func_; -}; - // This pass uses mechanisms listed in RFC: // https://github.com/tensorflow/community/pull/113 // It prepares composite functions that are attributed to indicate @@ -161,6 +110,10 @@ class ConvertNMSPaddedFunc { class PrepareCompositeFunctionsPass : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: explicit PrepareCompositeFunctionsPass() {} @@ -219,6 +172,12 @@ void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes( if (failed(ConvertTFTextAPI(func, api_name, attr))) { return signalPassFailure(); } + } else if (api_name == kCustomSSDPostprocessing) { + ConvertSSDPostProcessFunc convert_ssd_postprocess(func, attr); + if (failed(convert_ssd_postprocess.VerifySignature()) || + failed(convert_ssd_postprocess.RewriteFunc())) { + return signalPassFailure(); + } } } @@ -261,7 +220,15 @@ LogicalResult CheckFusableKerasLstm(FuncOp lstm_func, ModuleOp module) { for (int i = 1; i < 3; ++i) { auto input = lstm_func.getArgument(i); auto input_type = input.getType().dyn_cast_or_null(); - if (!input_type || !input_type.hasStaticShape()) return failure(); + if (!input_type || !input_type.hasStaticShape()) { + lstm_func.emitWarning( + "we cannot fuse this lstm func because the batch size is not fixed, " + "please consider setting fixed batch size like " + "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/" + "lite/examples/experimental_new_converter/" + "Keras_LSTM_fusion_Codelab.ipynb"); + return failure(); + } } return success(); diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td index f5b252773f6..326b6b23398 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td @@ -40,7 +40,7 @@ def : Pat< (TF_MulOp $t, (TF_MulOp:$mul (TF_RsqrtOp (TF_AddOp $v, (TF_ConstOp $variance_epsilon))), $gamma)), (TF_SubOp $beta, (TF_MulOp $m, $mul)))>; -// Converts tf.FusedBatchNorm & tf.FusedBatchNormV3 into a sequence of more primitive arithmetic +// Converts tf.FusedBatchNormV3 into a sequence of more primitive arithmetic // operations. Specifically, performs the following calculation: // // (x - mean) * scale / sqrt(variance + epsilon) + offset @@ -50,28 +50,6 @@ def : Pat< // (x - mean) * scale / sqrt(variance + epsilon) + offset, // is then to compute // (x * multiplier) + (offset - mean * multiplier). -def : Pattern< - (TF_FusedBatchNormOp:$root - $x, $scale, $offset, $mean, $variance, - F32Attr:$epsilon, $exponential_avg_factor, - $data_format, FalseBoolAttr:$is_training), - [(TF_AddOp - (TF_MulOp - $x, - (TF_MulOp:$multiplier - $scale, - (TF_RsqrtOp - (TF_AddOp $variance, - (TF_ConstOp $epsilon))))), - (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))), - // We already guaranteed that the last four results has no use so it does - // not matter what value we provide here for replacement. 
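The new warning in CheckFusableKerasLstm above asks users to fix the batch size before conversion. A minimal Python sketch of what that looks like (the layer sizes are arbitrary and this mirrors the linked Keras LSTM fusion codelab rather than code in this change):

    import tensorflow as tf

    # A fixed batch size gives the LSTM composite function static input shapes,
    # which the fusion check above requires.
    inputs = tf.keras.Input(shape=(28, 28), batch_size=1)
    x = tf.keras.layers.LSTM(20)(inputs)
    outputs = tf.keras.layers.Dense(10)(x)
    model = tf.keras.Model(inputs, outputs)

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()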
- /*batch_mean=*/(replaceWithValue $x), - /*batch_variance=*/(replaceWithValue $x), - /*reserve_space_1=*/(replaceWithValue $x), - /*reserve_space_2=*/(replaceWithValue $x)], - [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2), - (HasNoUseOf:$root__3), (HasNoUseOf:$root__4)]>; def : Pattern< (TF_FusedBatchNormV3Op:$root diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 9a27d0de62a..783f21fce21 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project @@ -68,6 +69,11 @@ namespace { // training quantization simpler. class PrepareQuantizePass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: // Constructor used by the PassRegistration and enforce uint8 quantization. // This is only used by test. @@ -122,6 +128,10 @@ class PrepareQuantizePass // the best quantization practise. This also fixes some simple violations. void SanityCheckAndAdjustment(FuncOp func); + // Whether the func contains Quantize ops. This is used to determine whether + // to use the quantization parameters from the fixed output range property. + bool ContainsQuantizeOps(FuncOp func); + QuantizationSpecs quant_specs_; }; @@ -285,6 +295,13 @@ void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) { }); } +bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) { + for (const auto& op : func.getOps()) { + if (llvm::isa(op)) return true; + } + return false; +} + using PrepareQuantStats = quant::ConvertStatsToQDQs; @@ -309,6 +326,7 @@ void PrepareQuantizePass::runOnFunction() { OwningRewritePatternList patterns; bool is_signed = quant_specs_.IsSignedInferenceType(); int bit_width = quant_specs_.GetQuantizationTypeWidth(); + bool enforce_fixed_output_range = ContainsQuantizeOps(func); if (is_signed) { patterns.insert>(ctx); // Convert quant stats to int8 quantization parameters. @@ -327,7 +345,8 @@ void PrepareQuantizePass::runOnFunction() { // values (tensors). ApplyQuantizationParamsPropagation( func, is_signed, disable_per_channel || quant_specs_.disable_per_channel, - GetOpQuantSpec); + GetOpQuantSpec, + enforce_fixed_output_range || quant_specs_.post_training_quantization); ConvertMlirQuantOpsToTFLQuantOps(func); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 62688937d7e..2b118d0b810 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -40,6 +40,7 @@ limitations under the License. 
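The prepare_quantize.cc change above also enforces fixed output ranges when quant_specs_.post_training_quantization is set; that flag roughly corresponds to the usual post-training quantization flow on the Python side, sketched here (illustrative only; the saved-model path and input shape are hypothetical):

    import numpy as np
    import tensorflow as tf

    converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/my_model")  # hypothetical path

    def representative_dataset():
      for _ in range(10):
        yield [np.random.rand(1, 28, 28).astype(np.float32)]

    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset
    tflite_quant_model = converter.convert()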
#include "llvm/Support/Debug.h" #include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -57,6 +58,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" @@ -78,13 +80,23 @@ namespace { // Prepare TF operations in functions for subsequent legalization. class PrepareTFPass : public PassWrapper { public: - explicit PrepareTFPass() : unfold_batch_matmul_(true) {} - explicit PrepareTFPass(bool unfold_batch_matmul) - : unfold_batch_matmul_(unfold_batch_matmul) {} + PrepareTFPass() = default; + PrepareTFPass(const PrepareTFPass &) {} + explicit PrepareTFPass(bool unfold_batch_matmul) { + unfold_batch_matmul_ = unfold_batch_matmul; + } void runOnFunction() override; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + private: - bool unfold_batch_matmul_; + Option unfold_batch_matmul_{ + *this, "tfl-unfold-batch-matmul", + llvm::cl::desc("Unfold BatchMatMul into individual MatMul ops."), + llvm::cl::init(true)}; }; template @@ -203,9 +215,8 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp } // Use the min/max from the operands and the num_bits and narrow_range // attribute to create the quantization parameter for the new quantize op. - rewriter.setInsertionPointAfter(tf_op); - IntegerAttr num_bits = - rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue()); + rewriter.setInsertionPointAfter(tf_op.getOperation()); + IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits()); BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range()); Type res_type = tf_op.getType(); TypeAttr qtype = quant::GetQuantizedTypeAttr( @@ -526,8 +537,8 @@ struct ConvertTFStridedSlice : public RewritePattern { loc, new_output_type, original_input, shape); // Replace the original strided_slice. - llvm::APInt new_begin_mask = strided_slice_op.begin_mask(); - llvm::APInt new_end_mask = strided_slice_op.end_mask(); + uint64_t new_begin_mask = strided_slice_op.begin_mask(); + uint64_t new_end_mask = strided_slice_op.end_mask(); // Since we expand the dims, we need to apply them to the begin_mask & // end_mask. new_begin_mask |= strided_slice_op.new_axis_mask(); @@ -590,8 +601,8 @@ struct ConvertTFStridedSlice : public RewritePattern { const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1; - int64_t begin_mask = strided_slice_op.begin_mask().getSExtValue(); - int64_t end_mask = strided_slice_op.end_mask().getSExtValue(); + int64_t begin_mask = strided_slice_op.begin_mask(); + int64_t end_mask = strided_slice_op.end_mask(); int64_t new_begin_mask = 0; int64_t new_end_mask = 0; @@ -627,13 +638,16 @@ struct ConvertTFStridedSlice : public RewritePattern { ++index; // After the ellipsis. 
- for (; index < begin_shape[0]; ++index) { + for (; index < begin_shape[0];) { padded_begin.push_back(begin_dense_elem_attr.getValue(index)); padded_end.push_back(end_dense_elem_attr.getValue(index)); padded_stride.push_back(stride_dense_elem_attr.getValue(index)); if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index); if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index); + + ++index; + ++new_index; } auto attribute_type = rewriter.getIntegerType(64); @@ -669,16 +683,16 @@ struct ConvertTFStridedSlice : public RewritePattern { // TODO(renjieliu): Consider expand the transformation for shrink mask as // well. - if (strided_slice_op.shrink_axis_mask().getZExtValue()) return failure(); + if (strided_slice_op.shrink_axis_mask()) return failure(); // Handle new axis mask. - uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue(); + uint64_t new_axis_mask = strided_slice_op.new_axis_mask(); if (new_axis_mask != 0) { return RewriteNewAxisMask(strided_slice_op, new_axis_mask, rewriter); } // Handle ellipsis mask. - uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask().getZExtValue(); + uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask(); if (ellipsis_mask != 0) { return RewriteEllipsisMask(strided_slice_op, ellipsis_mask, rewriter); } @@ -686,6 +700,71 @@ struct ConvertTFStridedSlice : public RewritePattern { } }; +struct ConvertTFBroadcastTo : public RewritePattern { + explicit ConvertTFBroadcastTo(MLIRContext *context) + : RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto tf_broadcast_to_op = cast(op); + auto input_type = tf_broadcast_to_op.input().getType().cast(); + auto output_type = tf_broadcast_to_op.output().getType().cast(); + auto shape_type = tf_broadcast_to_op.shape().getType().cast(); + Type element_type = input_type.getElementType(); + + // Allow lowering when low dimension inputs are given and its type is F32 or + // I32. 
+ if (!((output_type.hasRank() && output_type.getRank() <= 5) || + (shape_type.hasStaticShape() && shape_type.getRank() == 1 && + shape_type.getDimSize(0) <= 5))) + return failure(); + + if (!(element_type.isa() || + element_type.isInteger(32))) + return failure(); + + auto status_or_const_op = + CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1); + if (!status_or_const_op.ok()) { + return failure(); + } + + auto tf_fill_op = rewriter.create( + op->getLoc(), output_type, tf_broadcast_to_op.shape(), + status_or_const_op.ValueOrDie()); + + auto mul_op = rewriter.create( + op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op); + rewriter.replaceOp(op, mul_op.getResult()); + return success(); + } +}; + +struct ConvertFusedBatchNorm : public OpRewritePattern { + explicit ConvertFusedBatchNorm(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op, + PatternRewriter &rewriter) const override { + auto new_result_types = + llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes()); + // reserve_space_3 + new_result_types.push_back( + UnrankedTensorType::get(FloatType::getF32(rewriter.getContext()))); + + OperationState new_state(tf_fused_batch_norm_op.getLoc(), + TF::FusedBatchNormV3Op::getOperationName(), + tf_fused_batch_norm_op.getOperands(), + new_result_types, + tf_fused_batch_norm_op.getAttrs()); + Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state); + + rewriter.replaceOp(tf_fused_batch_norm_op, + tf_fused_batch_norm_op_v3->getResults().drop_back()); + return success(); + } +}; + #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc" // Returns success if all the operations in the `op`'s regions including `op` @@ -711,14 +790,113 @@ LogicalResult ConvertTf2XlaOps(FuncOp func, MLIRContext *context) { target.addLegalOp(); target.addLegalOp(); target.addIllegalOp(); + target.addIllegalOp(); OwningRewritePatternList patterns; mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns); + mhlo::PopulateLegalizeTfPatterns(context, &patterns); TF::PopulateLegalizeHloToTfPatterns(&patterns, context); + mhlo::GatherOp::getCanonicalizationPatterns(patterns, context); return applyPartialConversion(func, target, patterns); } +// Convert rfft to rfft2d. +// The transformation pattern looks like below: +// +// input fft_len +// \ / +// rfft +// +// || +// \/ +// +// input fft_len +// \ / +// expand_dim concat with [1] at the front +// \ / +// rfft_2d +// | +// squeeze +struct ConvertRfftToRfft2d : public RewritePattern { + explicit ConvertRfftToRfft2d(MLIRContext *context) + : RewritePattern(TF::RFFTOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto rfft_op = dyn_cast(op); + + auto input = rfft_op.input(); + auto input_type = input.getType().dyn_cast_or_null(); + if (!input_type) return failure(); + auto fft_len = rfft_op.fft_length(); + auto fft_len_type = fft_len.getType().dyn_cast_or_null(); + if (!fft_len_type) return failure(); + + auto output_type = + rfft_op.getResult().getType().dyn_cast_or_null(); + if (!output_type) return failure(); + + // Expanded inputs. + // Insert at -2 location. 
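The ConvertTFBroadcastTo pattern above lowers tf.BroadcastTo to a Fill of ones followed by a Mul, relying on ordinary broadcasting. A NumPy check of that equivalence (illustrative only, not code from this change):

    import numpy as np

    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    shape = (4, 2, 3)

    # BroadcastTo(x, shape) == Mul(x, Fill(shape, 1)); the ones tensor carries the
    # target shape and elementwise multiplication broadcasts x up to it.
    assert np.array_equal(np.broadcast_to(x, shape),
                          x * np.ones(shape, dtype=x.dtype))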
+ auto one_ele_type = + mlir::RankedTensorType::get({1}, rewriter.getIntegerType(32)); + auto minus_two = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(), + one_ele_type, -2); + + SmallVector expanded_input_shape; + SmallVector expanded_output_shape; + int expanded_rank = input_type.getRank() + 1; + int r = 0; + for (int i = 0; i < expanded_rank; ++i) { + if (i == expanded_rank - 2) { + expanded_input_shape.push_back(1); + expanded_output_shape.push_back(1); + } else { + expanded_input_shape.push_back(input_type.getDimSize(r)); + expanded_output_shape.push_back(output_type.getDimSize(r)); + r++; + } + } + + auto expaned_input_type = mlir::RankedTensorType::get( + expanded_input_shape, input_type.getElementType()); + TF::ExpandDimsOp expanded_input = rewriter.create( + rfft_op.getLoc(), expaned_input_type, input, minus_two->getResult()); + + // Expanded fft_len. + auto one_attr = mlir::DenseIntElementsAttr::get(one_ele_type, {1}); + + auto one = rewriter.create(rfft_op.getLoc(), one_attr); + + auto zero = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(), + one_ele_type, 0); + + auto expanded_fft_len_type = + mlir::RankedTensorType::get({2}, fft_len_type.getElementType()); + + TF::ConcatV2Op expanded_fft_len = rewriter.create( + rfft_op.getLoc(), expanded_fft_len_type, + SmallVector({one.getResult(), fft_len}), zero->getResult()); + + // Insert the rfft_2d. + auto rfft2d_out_type = mlir::RankedTensorType::get( + expanded_output_shape, output_type.getElementType()); + TF::RFFT2DOp rfft2d = rewriter.create( + rfft_op.getLoc(), rfft2d_out_type, expanded_input.getResult(), + expanded_fft_len.getResult()); + + // Insert the squeeze op. + auto squeeze_dim = rewriter.getI64ArrayAttr({-2}); + TF::SqueezeOp squeeze = rewriter.create( + rfft_op.getLoc(), output_type, rfft2d.getResult(), squeeze_dim); + + rewriter.replaceOp(op, squeeze.getResult()); + + return success(); + } +}; + void PrepareTFPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); @@ -751,6 +929,8 @@ void PrepareTFPass::runOnFunction() { // replaced with a single Conv op with dilation parameter. patterns.insert, ConvertTFDilatedConvOp>(ctx); + + patterns.insert(ctx); TFL::populateWithGenerated(ctx, &patterns); // TODO(karimnosseir): Split to separate pass probably after // deciding on long term plan for this optimization. @@ -767,8 +947,9 @@ void PrepareTFPass::runOnFunction() { patterns.insert, TF::ConvertTFBatchMatMulOp>(ctx); } - patterns.insert(ctx); + patterns.insert(ctx); applyPatternsAndFoldGreedily(func, patterns); } diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index 3342981b75f..56b38ec58d8 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -80,7 +80,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) { // The basic block arguments correspond to values that are loop carried, while // all those post are loop independent. Initialize extern_values with while_op // not loop carried operands. - auto num_loop_carried = while_op.cond().front().getNumArguments(); + auto num_loop_carried = while_op.cond().getNumArguments(); auto not_carried_operands = while_op.getOperands().drop_front(num_loop_carried); extern_values.insert(not_carried_operands.begin(), @@ -124,8 +124,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) { // Collect new types. 
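The ConvertRfftToRfft2d rewrite above is exact because an FFT over a unit-length axis is the identity, so RFFT(x, n) equals Squeeze(RFFT2D(ExpandDims(x, -2), [1, n]), -2). A NumPy check (illustrative only, not code from this change):

    import numpy as np

    x = np.random.randn(3, 16).astype(np.float32)
    fft_len = 16

    direct = np.fft.rfft(x, n=fft_len, axis=-1)
    # Expand a unit dimension at -2, take a 2-D RFFT over [1, fft_len], squeeze it away.
    via_2d = np.squeeze(np.fft.rfft2(x[..., np.newaxis, :], s=(1, fft_len)), axis=-2)
    assert np.allclose(direct, via_2d)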
SmallVector types; types.reserve(extra_operands.size() + while_op.getNumOperands()); - for (BlockArgument ba : while_op.cond().front().getArguments()) - types.push_back(ba.getType()); + for (Type type : while_op.cond().getArgumentTypes()) types.push_back(type); for (Value operand : extern_values) types.push_back(operand.getType()); // Create outline function from region. Optional pass extra arguments through diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc new file mode 100644 index 00000000000..b32da24d00f --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc @@ -0,0 +1,98 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace mlir { +namespace TFL { + +stream_executor::port::StatusOr CreateConstOpWithSingleValue( + PatternRewriter* rewriter, Location loc, ShapedType shaped_type, + int value) { + Type element_type = shaped_type.getElementType(); + ShapedType scalar_type = RankedTensorType::get({}, element_type); + Attribute attr; + if (element_type.isF16()) { + auto floatType = mlir::FloatType::getF16(element_type.getContext()); + auto floatAttr = mlir::FloatAttr::get(floatType, static_cast(value)); + std::vector floatValues({floatAttr}); + attr = DenseElementsAttr::get(scalar_type, floatValues); + } else if (element_type.isBF16()) { + auto floatType = mlir::FloatType::getBF16(element_type.getContext()); + auto floatAttr = mlir::FloatAttr::get(floatType, static_cast(value)); + std::vector floatValues({floatAttr}); + attr = DenseElementsAttr::get(scalar_type, floatValues); + } else if (element_type.isF32()) { + attr = + DenseElementsAttr::get(scalar_type, static_cast(value)); + } else if (auto complex_type = element_type.dyn_cast()) { + auto etype = complex_type.getElementType(); + if (etype.isF32()) { + auto dialect = etype.getContext()->getLoadedDialect("tf"); + tensorflow::TensorProto repr; + repr.set_dtype(tensorflow::DT_COMPLEX64); + + tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape(); + shape->set_unknown_rank(false); + shape->add_dim()->set_size(int64_t{1}); + std::string content; + auto complex_value = std::complex(static_cast(value), 0.0f); + content.assign(reinterpret_cast(&complex_value), + sizeof(complex_value)); + repr.set_tensor_content(content); + std::string mangled = tensorflow::mangling_util::MangleTensor(repr); + + attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled); + } else { + return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, + "Unsupported type"); + } + } 
else if (auto itype = element_type.dyn_cast()) { + switch (itype.getWidth()) { + case 8: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + case 16: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + case 32: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + case 64: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + default: + return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, + "Unsupported type"); + } + } else { + return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, + "Unsupported type"); + } + return rewriter->create(loc, scalar_type, attr); +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.h b/tensorflow/compiler/mlir/lite/utils/constant_utils.h new file mode 100644 index 00000000000..5c348021b5e --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace mlir { +namespace TFL { + +// Returns a Constant op with a single value. 
+stream_executor::port::StatusOr CreateConstOpWithSingleValue( + PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value); + +} // namespace TFL +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc index 081ba7ac6e7..f26689fac5e 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -93,8 +93,9 @@ class LstmUtilsTest : public ::testing::Test { LstmUtilsTest() {} void SetUp() override { - RegisterDialects(); context_ = std::make_unique(); + context_->loadDialect(); builder_ = std::unique_ptr(new Builder(context_.get())); fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false); fused_lstm_func_cifg_ = @@ -109,12 +110,6 @@ class LstmUtilsTest : public ::testing::Test { builder_.reset(); } - void RegisterDialects() { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - } - FuncOp fused_lstm_func_; FuncOp fused_lstm_func_cifg_; FuncOp fused_ln_lstm_func_; diff --git a/tensorflow/compiler/mlir/lite/utils/nms_utils.cc b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc new file mode 100644 index 00000000000..e462d4f38b0 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc @@ -0,0 +1,174 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/utils/nms_utils.h" + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +namespace mlir { +namespace TFL { + +namespace { + +// TODO(b/162842801): Consolidate all util definitions of kTFImplements. 
+constexpr char kTFImplements[] = "tf._implements"; +constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess"; +constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2"; + +inline OpaqueElementsAttr CustomOption(OpBuilder* builder, + const std::string& content) { + ShapedType type = RankedTensorType::get( + {static_cast(content.size())}, builder->getIntegerType(8)); + return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"), + type, + StringRef(content.data(), content.size())); +} + +} // namespace + +void ConvertNMSPaddedFunc::RewriteFunc() { + func_.setAttr(kTFImplements, + StringAttr::get(kTfNMSPadded, func_.getContext())); + Value boxes = func_.getArgument(0); + Value scores = func_.getArgument(1); + Value max_output_size = func_.getArgument(2); + Value iou_threshold = func_.getArgument(3); + Value score_threshold = func_.getArgument(4); + auto output_type0 = func_.getType().getResult(0); + auto output_type1 = func_.getType().getResult(1); + + OpBuilder builder(func_.getBody()); + auto op = builder.create( + func_.getLoc(), output_type0, output_type1, boxes, scores, + max_output_size, iou_threshold, score_threshold); + + builder.create(func_.getLoc(), op.getResults()); +} + +LogicalResult ConvertNMSPaddedFunc::VerifySignature() { + // Verify high-level function signature. + // Relevant argument characteristics are checked by the TFL op definition. + if (func_.getNumArguments() < 5) { + return func_.emitError() + << "Invalid number of arguments to " + "non_max_suppression_padded_v2 (need atleast 5): " + << func_.getNumArguments(); + } + if (func_.getType().getNumResults() != 2) { + return func_.emitError() << "Invalid number of results from " + "non_max_suppression_padded_v2 (need 2): " + << func_.getType().getNumResults(); + } + // The TFLite fused op does not support batching yet. + // TODO(b/158709815): Add support for batches with padded NMS. 
+ auto boxes_type = func_.getArgument(0).getType().dyn_cast(); + if (!boxes_type.hasRank() || boxes_type.getRank() != 2) { + return func_.emitError() << "TFLite does not support batched input for " + "non_max_suppression_padded"; + } + return success(); +} + +LogicalResult ConvertSSDPostProcessFunc::RewriteFunc() { + func_.eraseBody(); + func_.addEntryBlock(); + func_.setAttr(kTFImplements, + StringAttr::get(kCustomSSDPostprocessing, func_.getContext())); + + OpBuilder builder(func_.getBody()); + std::string custom_option_buffer; + if (failed(CreateNMSCustomOptions(func_, attr_.GetAttrs(), + custom_option_buffer))) { + return failure(); + } + auto op = builder.create( + func_.getLoc(), func_.getType().getResults(), func_.getArguments(), + kCustomSSDPostprocessing, CustomOption(&builder, custom_option_buffer)); + builder.create(func_.getLoc(), op.getResults()); + + return success(); +} + +LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions( + FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) { + flexbuffers::Builder fbb; + size_t start_map = fbb.StartMap(); + + if (failed(AddIntAttr(func, attrs, "max_detections", &fbb)) || + failed(AddIntAttr(func, attrs, "max_classes_per_detection", &fbb)) || + failed(AddIntAttr(func, attrs, "num_classes", &fbb)) || + failed(AddFloatAttr(func, attrs, "nms_score_threshold", &fbb)) || + failed(AddFloatAttr(func, attrs, "nms_iou_threshold", &fbb)) || + failed(AddFloatAttr(func, attrs, "y_scale", &fbb)) || + failed(AddFloatAttr(func, attrs, "x_scale", &fbb)) || + failed(AddFloatAttr(func, attrs, "h_scale", &fbb)) || + failed(AddFloatAttr(func, attrs, "w_scale", &fbb))) + return failure(); + auto use_regular_nms = + attrs.get("use_regular_nms").dyn_cast_or_null(); + if (!use_regular_nms) { + return func.emitError() + << "use_regular_nms attribute is not set or not a bool"; + } + fbb.Int("use_regular_nms", use_regular_nms.getValue()); + + fbb.EndMap(start_map); + fbb.Finish(); + custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end()); + return success(); +} + +LogicalResult ConvertSSDPostProcessFunc::AddIntAttr( + FuncOp func, DictionaryAttr attrs, const std::string& attribute, + flexbuffers::Builder* builder) { + auto int_attr = attrs.get(attribute).dyn_cast_or_null(); + if (!int_attr) { + return func.emitError() + << attribute.c_str() << " attribute is not set or not an integer"; + } + builder->Int(attribute.c_str(), int_attr.getInt()); + return success(); +} + +LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr( + FuncOp func, DictionaryAttr attrs, const std::string& attribute, + flexbuffers::Builder* builder) { + auto float_attr = attrs.get(attribute).dyn_cast_or_null(); + if (!float_attr) { + return func.emitError() + << attribute.c_str() << " attribute is not set or not a float"; + } + builder->Float(attribute.c_str(), float_attr.getValue().convertToFloat()); + return success(); +} + +LogicalResult ConvertSSDPostProcessFunc::VerifySignature() { + // Verify high-level function signature. 
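The non_max_suppression_padded_v2 composite matched by ConvertNMSPaddedFunc above is, as far as its signature checks suggest, the one traced out of tf.image.non_max_suppression_padded (at least 5 arguments, 2 results, unbatched rank-2 boxes). A minimal call that exercises that API (illustrative only; the thresholds are arbitrary):

    import tensorflow as tf

    boxes = tf.random.uniform([8, 4])    # rank-2, i.e. unbatched, as required above
    scores = tf.random.uniform([8])
    selected_indices, num_valid = tf.image.non_max_suppression_padded(
        boxes, scores, max_output_size=4,
        iou_threshold=0.5, score_threshold=0.1,
        pad_to_max_output_size=True)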
+ if (func_.getNumArguments() != 3) { + return func_.emitError() + << "Invalid number of arguments to " << kCustomSSDPostprocessing + << ": " << func_.getNumArguments(); + } + if (func_.getType().getNumResults() != 4) { + return func_.emitError() + << "Invalid number of results from " << kCustomSSDPostprocessing + << ": " << func_.getType().getNumResults(); + } + return success(); +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/nms_utils.h b/tensorflow/compiler/mlir/lite/utils/nms_utils.h new file mode 100644 index 00000000000..6a9035e0c81 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/nms_utils.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TFLite transformation +// passes to work with NMS ops in TFLite. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_ + +#include + +#include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" + +namespace mlir { +namespace TFL { + +// Abstracts the conversion of the padded NMS composite function. +class ConvertNMSPaddedFunc { + public: + explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {} + + void RewriteFunc(); + + LogicalResult VerifySignature(); + + private: + FuncOp func_; +}; + +// Abstracts the conversion of the SSD post-processing composite function to +// TFLite. 
+class ConvertSSDPostProcessFunc { + public: + explicit ConvertSSDPostProcessFunc(FuncOp func, mlir::TF::FuncAttr attr) + : func_(func), attr_(attr) {} + + LogicalResult RewriteFunc(); + + LogicalResult VerifySignature(); + + private: + LogicalResult CreateNMSCustomOptions(FuncOp func, DictionaryAttr attrs, + std::string& custom_option_buffer); + + LogicalResult AddIntAttr(FuncOp func, DictionaryAttr attrs, + const std::string& attribute, + flexbuffers::Builder* builder); + + LogicalResult AddFloatAttr(FuncOp func, DictionaryAttr attrs, + const std::string& attribute, + flexbuffers::Builder* builder); + + FuncOp func_; + mlir::TF::FuncAttr attr_; +}; + +} // end namespace TFL +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc index 96d22cb51e9..cce8038d3fa 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc @@ -47,6 +47,7 @@ namespace { constexpr char kNgrams[] = "tftext:Ngrams"; constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer"; +constexpr char kCustomSgnnProjection[] = "tftext:custom:SgnnProjection"; constexpr char kTFImplements[] = "tf._implements"; using mlir::TF::FuncAttr; @@ -56,9 +57,9 @@ inline OpaqueElementsAttr CustomOption(OpBuilder* builder, const std::string& content) { ShapedType type = RankedTensorType::get( {static_cast(content.size())}, builder->getIntegerType(8)); - return OpaqueElementsAttr::get( - builder->getContext()->getRegisteredDialect("tfl"), type, - StringRef(content.data(), content.size())); + return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"), + type, + StringRef(content.data(), content.size())); } inline TensorType GetInputType(FuncOp func, int idx) { @@ -269,6 +270,85 @@ LogicalResult ConvertNgrams(FuncOp func, llvm::StringRef api, FuncAttr attr) { return success(); } +LogicalResult VerifySgnnProjection(FuncOp func, FuncAttr attr) { + if (func.getType().getNumInputs() != 2 || + func.getType().getNumResults() != 1) { + return func.emitError() << "Mismatched number of inputs and outputs."; + } + auto values_type = GetInputType(func, 0); + if (!values_type || !values_type.getElementType().isa()) { + return func.emitError() << "First input should be a string tensor"; + } + auto row_splits_type = GetInputType(func, 1); + if (!row_splits_type || + !row_splits_type.getElementType().isa()) { + return func.emitError() << "Second input should be an integer tensor"; + } + + auto hash_seed = + attr.GetAttrs().get("hash_seed").dyn_cast_or_null(); + if (!hash_seed) { + return func.emitError() + << "'hash_seed' attribute is not set or not an array"; + } + auto output_type = GetResultType(func, 0); + if (!output_type || !output_type.getElementType().isa() || + !RankEquals(output_type, 2)) { + return func.emitError() << "Output should be a 2D float tensor."; + } + if (output_type.getDimSize(1) != hash_seed.size()) { + return func.emitError() + << "Output 2nd dimension should be the num of hash seeds."; + } + + auto buckets = attr.GetAttrs().get("buckets").dyn_cast_or_null(); + if (!buckets) { + return func.emitError() << "'buckets' attribute is not set or not int"; + } + + return success(); +} + +LogicalResult CreateSgnnProjectionCustomOption( + FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) { + flexbuffers::Builder fbb; + size_t start_map = fbb.StartMap(); + + auto hash_seed = 
attrs.get("hash_seed").dyn_cast_or_null(); + auto vector_start = fbb.StartVector("hash_seed"); + for (int i = 0; i < hash_seed.size(); i++) { + fbb.Add(static_cast( + (hash_seed.getValue().data() + i)->dyn_cast().getInt())); + } + fbb.EndVector(vector_start, /*typed=*/true, /*fixed=*/false); + + auto buckets = attrs.get("buckets").dyn_cast_or_null(); + fbb.Int("buckets", buckets.getInt()); + + fbb.EndMap(start_map); + fbb.Finish(); + custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end()); + return success(); +} + +LogicalResult ConvertSgnnProjection(FuncOp func, llvm::StringRef api, + FuncAttr attr) { + // See more details in tensorflow_models/sequence_projection/sgnn/sgnn.py + func.eraseBody(); + func.addEntryBlock(); + func.setAttr(kTFImplements, attr); + OpBuilder builder(func.getBody()); + std::string custom_option_buffer; + if (failed(CreateSgnnProjectionCustomOption(func, attr.GetAttrs(), + custom_option_buffer))) { + return failure(); + } + auto op = builder.create( + func.getLoc(), func.getType().getResults(), func.getArguments(), api, + CustomOption(&builder, custom_option_buffer)); + builder.create(func.getLoc(), op.getResults()); + return success(); +} } // namespace LogicalResult ConvertTFTextAPI(FuncOp func, llvm::StringRef api, @@ -281,6 +361,10 @@ LogicalResult ConvertTFTextAPI(FuncOp func, llvm::StringRef api, if (succeeded(VerifyNgrams(func))) { return ConvertNgrams(func, api, attr); } + } else if (api.str() == kCustomSgnnProjection) { + if (succeeded(VerifySgnnProjection(func, attr))) { + return ConvertSgnnProjection(func, api, attr); + } } return failure(); } diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 8e6d9042987..d97e12fbe45 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -91,16 +91,14 @@ MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() { return *global; } -static void RegisterDialects() { - static bool init_once = []() { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - return true; - }(); - (void)init_once; +static void RegisterDialects(mlir::DialectRegistry& registry) { + // clang-format off + registry.insert(); + // clang-format on } Status MlirFunctionOptimizationPass::Run( @@ -126,12 +124,18 @@ Status MlirFunctionOptimizationPass::Run( << " passes)"; GraphDebugInfo debug_info; - RegisterDialects(); mlir::MLIRContext context; + RegisterDialects(context.getDialectRegistry()); GraphImportConfig import_config; import_config.graph_as_function = true; import_config.control_outputs = *control_ret_node_names; import_config.upgrade_legacy = true; + // Disable shape inference during import as some TensorFlow op fails during + // shape inference with dynamic shaped operands. This in turn causes the + // import to fail. Shape inference during import is going to be removed and + // the shape inference pass is run early in the pass pipeline, shape inference + // during import is not necessary. 
+ import_config.enable_shape_inference = false; TF_ASSIGN_OR_RETURN(auto module_ref, ConvertGraphToMlir(**graph, debug_info, *flib_def, import_config, &context)); @@ -200,8 +204,8 @@ Status MlirV1CompatGraphOptimizationPass::Run( << " passes)"; GraphDebugInfo debug_info; - RegisterDialects(); mlir::MLIRContext context; + RegisterDialects(context.getDialectRegistry()); GraphImportConfig import_config; import_config.upgrade_legacy = true; // Restrict functionalization to TPU nodes to avoid problems in v1 session diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc index bce0ed4a33d..6b605741355 100644 --- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc +++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/mlir/utils/name_utils.h" static inline absl::string_view StringRefToView(llvm::StringRef ref) { return absl::string_view(ref.data(), ref.size()); @@ -103,62 +104,16 @@ int OpOrArgNameMapper::InitOpName(OpOrVal op_or_val, llvm::StringRef name) { bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) { return true; } -namespace { -// Derives name from location. -std::string GetNameFromLoc(mlir::Location loc) { - llvm::SmallVector loc_names; - llvm::SmallVector locs; - locs.push_back(loc); - bool names_is_nonempty = false; - - while (!locs.empty()) { - mlir::Location curr_loc = locs.pop_back_val(); - - if (auto name_loc = curr_loc.dyn_cast()) { - // Add name in NameLoc. For NameLoc we also account for names due to ops - // in functions where the op's name is first. - auto name = name_loc.getName().strref().split('@').first; - loc_names.push_back(name); - if (!name.empty()) names_is_nonempty = true; - continue; - } else if (auto call_loc = curr_loc.dyn_cast()) { - // Add name if CallSiteLoc's callee has a NameLoc (as should be the - // case if imported with DebugInfo). - if (auto name_loc = call_loc.getCallee().dyn_cast()) { - auto name = name_loc.getName().strref().split('@').first; - loc_names.push_back(name); - if (!name.empty()) names_is_nonempty = true; - continue; - } - } else if (auto fused_loc = curr_loc.dyn_cast()) { - // Push all locations in FusedLoc in reverse order, so locations are - // visited based on order in FusedLoc. - auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations()); - locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end()); - continue; - } - - // Location is not a supported, so an empty StringRef is added. - loc_names.push_back(llvm::StringRef()); - } - - if (names_is_nonempty) - return llvm::join(loc_names.begin(), loc_names.end(), ";"); - - return ""; -} -} // anonymous namespace - std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) { if (auto* op = op_or_val.dyn_cast()) { - auto name_from_loc = GetNameFromLoc(op->getLoc()); + auto name_from_loc = mlir::GetNameFromLoc(op->getLoc()); if (!name_from_loc.empty()) return name_from_loc; // If the location is none of the expected types, then simply use name // generated using the op type. 
return std::string(op->getName().getStringRef()); } auto val = op_or_val.dyn_cast(); - auto name_from_loc = GetNameFromLoc(val.getLoc()); + auto name_from_loc = mlir::GetNameFromLoc(val.getLoc()); if (!name_from_loc.empty()) return name_from_loc; // If the location is none of the expected types, then simply use name // generated using the op type. Follow TF convention and append the result diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index 5bbfba773a3..66283bded71 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -10,6 +10,7 @@ cc_library( deps = [ "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_helper", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", @@ -35,6 +36,9 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index 5ce0ca8cfcb..066726593a7 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -16,19 +16,53 @@ limitations under the License. #include #include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project #include "mlir/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/op.h" namespace tensorflow { +namespace { + +// Runs pass pipeline `pass_pipeline` on `module` if `pass_pipeline` is not +// empty. 
+std::string RunPassPipelineOnModule(mlir::ModuleOp module, + const std::string &pass_pipeline, + TF_Status *status) { + if (!pass_pipeline.empty()) { + mlir::PassManager pm(module.getContext()); + std::string error; + llvm::raw_string_ostream error_stream(error); + if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + ("Invalid pass_pipeline: " + error_stream.str()).c_str()); + return "// error"; + } + + mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext()); + if (failed(pm.run(module))) { + Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus()); + return "// error"; + } + } + return MlirModuleToString(module); +} + +} // anonymous namespace + std::string ImportGraphDef(const std::string &proto, const std::string &pass_pipeline, TF_Status *status) { @@ -47,24 +81,43 @@ std::string ImportGraphDef(const std::string &proto, return "// error"; } - // Run the pass_pipeline on the module if not empty. - if (!pass_pipeline.empty()) { - mlir::PassManager pm(&context); - std::string error; - llvm::raw_string_ostream error_stream(error); - if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, - ("Invalid pass_pipeline: " + error_stream.str()).c_str()); - return "// error"; - } + return RunPassPipelineOnModule(module->get(), pass_pipeline, status); +} - mlir::StatusScopedDiagnosticHandler statusHandler(&context); - if (failed(pm.run(*module.ValueOrDie()))) { - Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus()); - return "// error"; - } +std::string ImportFunction(const std::string &functiondef_proto, + const std::string &functiondef_library_proto, + const std::string &pass_pipeline, + TF_Status *status) { + FunctionDef functiondef; + auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef); + if (!s.ok()) { + Set_TF_Status_from_Status(status, s); + return "// error"; } - return MlirModuleToString(*module.ConsumeValueOrDie()); + + FunctionDefLibrary fdef_lib; + s = tensorflow::LoadProtoFromBuffer(functiondef_library_proto, &fdef_lib); + if (!s.ok()) { + Set_TF_Status_from_Status(status, s); + return "// error"; + } + + FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); + s = flib_def.AddFunctionDef(functiondef); + if (!s.ok()) { + Set_TF_Status_from_Status(status, s); + return "// error"; + } + + const std::string &function_name = functiondef.signature().name(); + mlir::MLIRContext context; + auto module = ConvertFunctionToMlir(function_name, flib_def, &context); + if (!module.ok()) { + Set_TF_Status_from_Status(status, module.status()); + return "// error"; + } + + return RunPassPipelineOnModule(module->get(), pass_pipeline, status); } std::string ExperimentalConvertSavedModelToMlir( @@ -150,6 +203,7 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt, bool show_debug_info, TF_Status *status) { mlir::MLIRContext context; + mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry()); mlir::OwningModuleRef module; { mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); @@ -164,6 +218,7 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt, mlir::PassManager pm(&context); std::string error; llvm::raw_string_ostream error_stream(error); + mlir::registerAllPasses(); if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) { TF_SetStatus(status, TF_INVALID_ARGUMENT, ("Invalid pass_pipeline: " + error_stream.str()).c_str()); diff --git 
a/tensorflow/compiler/mlir/python/mlir.h b/tensorflow/compiler/mlir/python/mlir.h index e68ac28124b..6133068a5e8 100644 --- a/tensorflow/compiler/mlir/python/mlir.h +++ b/tensorflow/compiler/mlir/python/mlir.h @@ -25,13 +25,23 @@ limitations under the License. namespace tensorflow { // Simple wrapper to support tf.mlir.experimental.convert_graph_def. -// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before -// returning it as a string. +// Load a GraphDef (binary or textual proto format), convert to MLIR, and +// (optionally) optimize the module before returning it as a string. // This is an early experimental API, ideally we should return a wrapper object // around a Python binding to the MLIR module. std::string ImportGraphDef(const std::string &proto, const std::string &pass_pipeline, TF_Status *status); +// Simple wrapper to support tf.mlir.experimental.convert_function. +// Load FunctionDef and FunctionDefLibrary (binary or textual proto format), +// convert to MLIR, and (optionally) optimize the module before returning it as +// a string. +// This is an early experimental API, ideally we should return a wrapper object +// around a Python binding to the MLIR module. +std::string ImportFunction(const std::string &functiondef_proto, + const std::string &functiondef_library_proto, + const std::string &pass_pipeline, TF_Status *status); + // Load a SavedModel and return a textual MLIR string corresponding to it. // // Args: diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD index 5e21dddd444..31bce8d1bf6 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD @@ -21,6 +21,7 @@ tf_python_pybind_extension( "//tensorflow/python:pybind11_lib", "//tensorflow/python:pybind11_status", "@llvm-project//llvm:Support", + "@llvm-project//llvm:filecheck-lib", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:StandardOps", @@ -37,6 +38,7 @@ tf_python_pybind_extension( "//tensorflow/python:pybind11_lib", "//tensorflow/python:pybind11_status", "@llvm-project//llvm:Support", + "@llvm-project//llvm:filecheck-lib", "@pybind11", ], ) diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc index 25adb44fe1d..5ae638851f4 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "llvm/Support/FileCheck.h" +#include "llvm/FileCheck/FileCheck.h" #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc index 8a841856b72..051952ebaba 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
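Per the mlir.h comments above, ImportGraphDef backs tf.mlir.experimental.convert_graph_def, and the new ImportFunction is described as backing a tf.mlir.experimental.convert_function wrapper. A minimal sketch of the existing GraphDef entry point (illustrative only; the exact Python signature of the new convert_function wrapper is not shown in this change):

    import tensorflow as tf

    # Build a trivial GraphDef and import it into MLIR textual form.
    graph = tf.Graph()
    with graph.as_default():
      a = tf.constant(1.0, name="a")
      b = tf.constant(2.0, name="b")
      tf.math.add(a, b, name="sum")

    # An empty pass_pipeline just round-trips the import; a registered pipeline
    # such as "tf-standard-pipeline" can be supplied instead.
    print(tf.mlir.experimental.convert_graph_def(
        graph.as_graph_def(), pass_pipeline=""))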
==============================================================================*/ -#include "llvm/Support/FileCheck.h" +#include "llvm/FileCheck/FileCheck.h" #include "llvm/Support/SourceMgr.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc index 63ca4c7bb28..6cd49cf368d 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc @@ -22,22 +22,25 @@ limitations under the License. #include "mlir/Parser.h" // from @llvm-project #include "pybind11/pybind11.h" #include "pybind11/stl.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/python/lib/core/pybind11_lib.h" #include "tensorflow/python/lib/core/pybind11_status.h" PYBIND11_MODULE(mlir_wrapper, m) { - m.def("registerDialects", []() { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); + m.def("preloadTensorFlowDialects", [](mlir::MLIRContext &context) { + mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry()); + context.getDialectRegistry().loadAll(&context); }); + m.def("verify", [](std::string input) { llvm::SourceMgr SM = llvm::SourceMgr(); SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input), llvm::SMLoc()); mlir::MLIRContext ctx; + mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry()); + ctx.getDialectRegistry().loadAll(&ctx); auto module = mlir::parseSourceFile(SM, &ctx); if (!module) { return false; diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc index 2be67f8e93e..be2dc2065f3 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc @@ -20,11 +20,6 @@ limitations under the License. 
void init_types(py::module& m) { // Type py::class_ Type(m, "Type"); - Type.def("getKind", &mlir::Type::getKind); - - // Type Enums - py::enum_(Type, "StandardTypes_Kind") - .value("BF16", mlir::StandardTypes::BF16); // Type Sub-classes py::class_(m, "FunctionType") @@ -32,7 +27,10 @@ void init_types(py::module& m) { [](mlir::FunctionType& ft) { return ft.getResults().vec(); }); py::class_(m, "FloatType") - .def("get", &mlir::FloatType::get); + .def("getBF16", &mlir::FloatType::getBF16) + .def("getF16", &mlir::FloatType::getF16) + .def("getF32", &mlir::FloatType::getF32) + .def("getF64", &mlir::FloatType::getF64); py::class_(m, "IntegerType") .def("get", py::overload_cast( diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index 45c8dce8422..f9870183b88 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -74,7 +74,7 @@ tool_names = [ 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir', - 'kernel-gen-opt', 'xla-thunks-opt' + 'kernel-gen-opt', 'xla-thunks-opt', 'tfjs-opt' ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index c6f0083fc92..7bdc3b0396f 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -13,6 +13,7 @@ package_group( "//learning/brain/experimental/dtensor/...", "//learning/brain/experimental/tfrt/...", "//learning/pathways/data_parallel/tf2xla/...", + "//platforms/xla/sparse_core/...", "//tensorflow/compiler/...", "//tensorflow/lite/experimental/tf_runtime/...", "//tensorflow/python/...", @@ -33,6 +34,7 @@ filegroup( "ir/tf_op_base.td", "ir/tf_op_interfaces.td", "ir/tf_ops.td", + "ir/tfrt_ops.td", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", @@ -124,6 +126,25 @@ gentbl( ], ) +gentbl( + name = "tensorflow_tfrt_ops_inc_gen", + tbl_outs = [ + ( + "-gen-op-decls", + "ir/tfrt_ops.h.inc", + ), + ( + "-gen-op-defs", + "ir/tfrt_ops.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ir/tfrt_ops.td", + td_srcs = [ + ":tensorflow_ops_td_files", + ], +) + # We only shard tf_op on name for build performance reasons. 
tf_ops_category_list = [ { @@ -343,6 +364,7 @@ cc_library( name = "tensorflow_" + target["name"], srcs = [ "ir/tf_ops.h", + "ir/tfrt_ops.h", "ir/tf_remaining_ops.h", "ir/tf_" + target["name"] + ".cc", "ir/tf_" + target["name"] + ".cc.inc", @@ -352,9 +374,11 @@ cc_library( textual_hdrs = [ "ir/tf_all_ops.h.inc", "ir/tf_ops_helpers.inc", + "ir/tfrt_ops.h.inc", "ir/tf_remaining_ops.h.inc", ] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list], deps = [ + ":attribute_utils", ":tensorflow_attributes", ":tensorflow_canonicalize_inc_gen", ":tensorflow_op_interfaces", @@ -385,6 +409,7 @@ cc_library( "ir/tf_ops.h", "ir/tf_remaining_ops.h", "ir/tf_remaining_ops.cc", + "ir/tfrt_ops.h", ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], hdrs = [ ], @@ -392,6 +417,49 @@ cc_library( "ir/tf_all_ops.h.inc", "ir/tf_ops_helpers.inc", "ir/tf_remaining_ops.h.inc", + "ir/tfrt_ops.h.inc", + ] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list], + deps = [ + ":tensorflow_attributes", + ":tensorflow_canonicalize_inc_gen", + ":tensorflow_op_interfaces", + ":tensorflow_op_interfaces_inc_gen", + ":tensorflow_remaining_ops_inc_gen", + ":tensorflow_side_effects", + ":tensorflow_structs", + ":tensorflow_tfrt_ops_inc_gen", + ":tensorflow_traits", + ":tensorflow_types", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LoopLikeInterface", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:SideEffects", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "tensorflow_tfrt_ops", + srcs = [ + "ir/tf_ops.h", + "ir/tfrt_ops.h", + "ir/tfrt_ops.cc", + "ir/tf_remaining_ops.h", + ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], + hdrs = [ + ], + textual_hdrs = [ + "ir/tf_all_ops.h.inc", + "ir/tf_ops_helpers.inc", + "ir/tfrt_ops.h.inc", + "ir/tf_remaining_ops.h.inc", ] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list], deps = [ ":tensorflow_attributes", @@ -401,6 +469,7 @@ cc_library( ":tensorflow_remaining_ops_inc_gen", ":tensorflow_side_effects", ":tensorflow_structs", + ":tensorflow_tfrt_ops_inc_gen", ":tensorflow_traits", ":tensorflow_types", "//tensorflow/core:framework", @@ -427,9 +496,11 @@ cc_library( textual_hdrs = [ "ir/tf_all_ops.h.inc", "ir/tf_remaining_ops.h", + "ir/tfrt_ops.h", ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], deps = [ ":tensorflow_all_ops_inc_gen", + ":tensorflow_tfrt_ops_inc_gen", ":tensorflow_remaining_ops_inc_gen", ":tensorflow_attributes", ":tensorflow_canonicalize_inc_gen", @@ -440,6 +511,7 @@ cc_library( ":tensorflow_traits", ":tensorflow_types", ":tensorflow_remaining_ops", + ":tensorflow_tfrt_ops", "@llvm-project//llvm:Support", "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:Dialect", @@ -512,6 +584,7 @@ cc_library( "ir/tf_saved_model.cc", ], hdrs = [ + "dialect_registration.h", "ir/tf_device.h", "ir/tf_executor.h", "ir/tf_ops.h", @@ -536,6 +609,7 @@ cc_library( ":tensorflow_ops", ":tensorflow_side_effects", ":tensorflow_structs", + ":tensorflow_tfrt_ops_inc_gen", ":tensorflow_traits", ":tensorflow_types", ":tf_saved_model_inc_gen", @@ -718,12 +792,13 @@ cc_library( deps = [ ":tensorflow", ":tensorflow_types", - 
"//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/core:framework", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", ], ) @@ -738,6 +813,7 @@ cc_library( "transforms/cluster_formation.cc", "transforms/cluster_outlining.cc", "transforms/collection_ops_util.cc", + "transforms/contraction_fusion.cc", "transforms/decompose_resource_ops_pass.cc", "transforms/device_index_selector.cc", "transforms/einsum.cc", @@ -769,24 +845,34 @@ cc_library( "transforms/replicate_to_island.cc", "transforms/resource_device_inference.cc", "transforms/resource_op_lifting.cc", + "transforms/resource_op_lifting_cleanup.cc", + "transforms/resource_op_lifting_cleanup.h", "transforms/rewrite_tpu_embedding_ops.cc", "transforms/shape_inference.cc", "transforms/shape_inference_pass.cc", "transforms/sink_constant.cc", "transforms/stack_ops_decomposition.cc", "transforms/tensor_array_ops_decomposition.cc", + "transforms/tensor_device_copy_conversion.cc", "transforms/tensor_list_ops_decomposition.cc", + "transforms/test_resource_alias_analysis.cc", "transforms/test_side_effect_analysis.cc", + "transforms/test_visitor_util.cc", "transforms/tf_data_optimization_pass.cc", "transforms/tf_device_assignment.cc", + "transforms/tpu_cluster_cleanup_attributes.cc", "transforms/tpu_cluster_formation.cc", + "transforms/tpu_colocate_composite_resource_ops.cc", "transforms/tpu_dynamic_layout_pass.cc", "transforms/tpu_dynamic_padding_mapper.cc", "transforms/tpu_extract_head_tail_outside_compilation.cc", "transforms/tpu_extract_outside_compilation.cc", "transforms/tpu_host_computation_expansion.cc", + "transforms/tpu_identity_pruning.cc", "transforms/tpu_merge_variables_with_execute.cc", "transforms/tpu_outside_compilation_cluster.cc", + "transforms/tpu_parallel_execute_sink_resource_write.cc", + "transforms/tpu_resource_read_for_write.cc", "transforms/tpu_rewrite_pass.cc", "transforms/tpu_sharding_identification_pass.cc", "transforms/tpu_space_to_depth_pass.cc", @@ -797,8 +883,6 @@ cc_library( "translate/tf_functional_to_executor.cc", ], hdrs = [ - "transforms/attribute_utils.h", - "transforms/batchmatmul_to_einsum.h", "transforms/bridge.h", "transforms/collection_ops_util.h", "transforms/einsum.h", @@ -806,7 +890,11 @@ cc_library( "transforms/shape_inference.h", ], includes = ["include"], + textual_hdrs = [ + "ir/tf_ops_helpers.inc", + ], deps = [ + ":attribute_utils", ":bridge_logger", ":convert_tensor", ":convert_type", @@ -815,7 +903,10 @@ cc_library( ":device_util", ":error_util", ":export_tf_dialect_op", + ":lower_tf_lib", ":mangling_util", + ":serialize_mlir_module_utils", + ":shape_inference_utils", ":tensorflow", ":tensorflow_analysis", ":tensorflow_optimize_inc_gen", @@ -824,9 +915,12 @@ cc_library( ":tpu_rewrite_device_util", ":translate_utils", ":unroll_batch_matmul_pass", + ":visitor_util", ":xla_sharding_util", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite:validators", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/compiler/xla/client:sharding_builder", @@ -843,6 +937,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", 
"@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", @@ -909,17 +1004,6 @@ cc_library( alwayslink = 1, ) -# Library with TensorFlow dialect static initialization. -cc_library( - name = "tensorflow_dialect_registration", - srcs = ["ir/dialect_registration.cc"], - deps = [ - ":tensorflow", - "@llvm-project//mlir:Shape", - ], - alwayslink = 1, -) - cc_library( name = "convert_graphdef", srcs = [ @@ -949,6 +1033,7 @@ cc_library( "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc/saved_model:loader_util", "//tensorflow/compiler/jit:shape_inference_helpers", + "//tensorflow/compiler/mlir:name_utils", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/tf2xla:functionalize_control_flow", "//tensorflow/compiler/xla:status_macros", @@ -1064,6 +1149,7 @@ cc_library( ":export_utils", ":tensorflow", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor/lib", @@ -1079,6 +1165,7 @@ cc_library( srcs = ["translate/translate_tf_dialect_op.cc"], deps = [ ":export_tf_dialect_op", + ":tensorflow", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -1264,7 +1351,7 @@ cc_library( name = "tf_dialect_passes", srcs = [ "transforms/constant_fold.cc", - "transforms/dialect_hooks.cc", + "transforms/decode_attributes_hook.cc", ], hdrs = [ "transforms/constant_fold.h", @@ -1292,9 +1379,8 @@ cc_library( cc_library( name = "tf_dialect_lib", deps = [ - ":tensorflow_dialect_registration", ":tf_dialect_passes", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", ], ) @@ -1305,6 +1391,7 @@ cc_library( deps = [ ":convert_graphdef", ":mlir_roundtrip_flags", + ":tensorflow", "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -1399,6 +1486,7 @@ cc_library( deps = [ ":convert_graphdef", ":mlir_roundtrip_flags", + ":tensorflow", ":translate_cl_options", ":translate_lib", "//tensorflow/core:protos_all_cc", @@ -1481,25 +1569,27 @@ gentbl( COMPILE_MLIR_UTIL_DEPS = [ ":bridge_logger", ":convert_graphdef", + ":convert_tensor", ":convert_type", ":dump_mlir_util", ":error_util", ":mlir_roundtrip_flags", + ":serialize_mlir_module_utils", ":tensorflow", - ":tensorflow_dialect_registration", ":tensorflow_types", ":tensorflow_passes", ":translate_utils", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "//tensorflow/compiler/mlir/hlo:hlo", + "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", "//tensorflow/compiler/mlir/hlo:sink_constants_to_control_flow", "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", "//tensorflow/compiler/mlir/xla:type_to_shape", @@ -1514,9 +1604,9 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor/lib", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", - ":convert_tensor", ] # Prefer to link 'compile_mlir_util' library that also links necessary @@ -1543,27 +1633,61 @@ cc_library( ], ) -tf_cc_test( - name = 
"compile_mlir_util_test", - size = "small", - srcs = ["utils/compile_mlir_util_test.cc"], +cc_library( + name = "compile_mlir_util_pass", + srcs = ["utils/compile_mlir_util_pass.cc"], deps = [ ":compile_mlir_util", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:resource_variable_ops", - "//tensorflow/cc:scope", - "//tensorflow/compiler/jit", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//tensorflow/stream_executor/lib", + "@llvm-project//mlir:Pass", ], + alwayslink = 1, +) + +cc_library( + name = "serialize_mlir_module_utils", + srcs = ["utils/serialize_mlir_module_utils.cc"], + hdrs = ["utils/serialize_mlir_module_utils.h"], + deps = [ + ":error_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + ], +) + +cc_library( + name = "tf_xla_mlir_translate", + srcs = ["utils/tf_xla_mlir_translate.cc"], + deps = [ + ":compile_mlir_util", + ":mlir_roundtrip_flags", + ":serialize_mlir_module_utils", + ":tensorflow", + ":translate_cl_options", + "//tensorflow/compiler/mlir:string_container_utils", + "//tensorflow/compiler/mlir/xla:translate_cl_options", + "//tensorflow/compiler/tf2xla:xla_argument", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Translation", + ], + alwayslink = 1, ) cc_library( @@ -1627,6 +1751,7 @@ cc_library( deps = [ ":lower_tf_inc_gen", ":tensorflow", + ":tensorflow_ops", ":tensorflow_types", "//tensorflow/core:framework", "@llvm-project//llvm:Support", @@ -1679,6 +1804,7 @@ cc_library( ":tensorflow", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -1738,14 +1864,13 @@ cc_library( ":convert_graphdef", ":error_util", ":tensorflow", - ":tensorflow_dialect_registration", ":tensorflow_passes", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -1761,6 +1886,7 @@ tf_cc_test( "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", + "//tensorflow/core:ops", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/platform:test", @@ -1781,6 +1907,21 @@ cc_library( ], ) +cc_library( + name = "visitor_util", + srcs = [ + "utils/visitor_util.cc", + ], + hdrs = [ + "utils/visitor_util.h", + ], + deps 
= [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "xla_sharding_util", srcs = [ @@ -1798,3 +1939,35 @@ cc_library( "@llvm-project//mlir:Support", ], ) + +cc_library( + name = "attribute_utils", + hdrs = ["utils/attribute_utils.h"], + deps = [ + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "shape_inference_utils", + srcs = ["utils/shape_inference_utils.cc"], + hdrs = ["utils/shape_inference_utils.h"], + deps = [ + ":convert_tensor", + ":convert_type", + ":export_utils", + ":tensorflow", + ":tensorflow_attributes", + ":tensorflow_types", + "//tensorflow/compiler/mlir:array_container_utils", + "//tensorflow/core:framework", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:types", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc index 3278c06fabe..d70bd01e490 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc @@ -21,11 +21,14 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "mlir/Analysis/CallGraph.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -34,19 +37,19 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/core/framework/resource_mgr.h" namespace mlir { namespace TF { +namespace detail { -namespace { //===----------------------------------------------------------------------===// // BacktrackAnalysisInfo //===----------------------------------------------------------------------===// @@ -86,9 +89,6 @@ class BacktrackAnalysisInfo { // Backtracked values indexed by the result number. llvm::SmallVector backtracked_values_; }; -} // namespace - -namespace detail { //===----------------------------------------------------------------------===// // BacktrackAnalysis @@ -137,12 +137,46 @@ class BacktrackAnalysis { return GetAnalysisForRegion(region); } + // Returns the backtrack analysis for the given region if it exists. + // If the region has not yet been analyzed, returns llvm::None. 
+ Optional GetAnalysisIfExists(Region& region) const { + auto it = info_map_.find(®ion); + if (it == info_map_.end()) return llvm::None; + return &it->second; + } + + Optional GetAnalysisIfExists(FuncOp func) const { + return GetAnalysisIfExists(func.getBody()); + } + private: llvm::SmallDenseMap info_map_; }; // Analyzes all regions attached to all operations in the module. BacktrackAnalysis::BacktrackAnalysis(ModuleOp module) { + const CallGraph call_graph(module); + + // Visit functions bottom up when doing the analysis. Note that SCC iterator + // has the property that if there is an edge from SCC1->SCC2, SCC1 is visited + // after SCC2, i.e., the graph is traversed bottom up just the way we want. + auto scc_begin = llvm::scc_begin(&call_graph); + auto scc_end = llvm::scc_end(&call_graph); + for (auto& scc : make_range(scc_begin, scc_end)) { + // Each SCC node is a collection of callgraph nodes that form a cycle. We + // will visit these nodes in an arbitrary order. If a node being visited + // calls a function that has not yet been analyzed, we will not be able to + // backtrack through that function call (our analysis will be correct but + // pessimistic). + for (CallGraphNode* node : scc) { + if (node->isExternal()) continue; + Region* region = node->getCallableRegion(); + GetOrCreateAnalysis(*region); + } + } + + // This above call graph analysis will cover all regions attached to functions + // but we also need to analyze regions attached to other ops. module.walk([this](Operation* op) { for (Region& region : op->getRegions()) GetOrCreateAnalysis(region); }); @@ -161,17 +195,26 @@ Value BacktrackAnalysis::BacktrackValue(Value value) { // in the Island body. if (value == island.control()) break; value = island.GetYield().getOperand(res_index); - } else if (isa(op)) { + } else if (isa(op)) { value = op->getOperand(res_index); + } else if (auto call = dyn_cast(op)) { + FuncOp func = dyn_cast(call.resolveCallable()); + if (!func) break; + // Check if the function being called has been analyzed. if not, + // we cannot backtrack the value further. + Optional callee_info = GetAnalysisIfExists(func); + if (!callee_info) break; + Optional passthrough_arg = callee_info.getValue()->GetArg(res_index); + if (!passthrough_arg) break; + value = call.getArgOperands()[passthrough_arg.getValue()]; + } else if (isa(op)) { + value = op->getRegion(0).front().getTerminator()->getOperand(res_index); } else { break; } } return value; } -} // namespace detail - -namespace { // Analyze the region. BacktrackAnalysisInfo::BacktrackAnalysisInfo( @@ -188,6 +231,8 @@ BacktrackAnalysisInfo::BacktrackAnalysisInfo( backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result)); } +namespace { + //===----------------------------------------------------------------------===// // ResourceAliasAnalysisInfo helper functions. //===----------------------------------------------------------------------===// @@ -196,12 +241,12 @@ constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; // Returns if a VarHandleOp is anonymous, which means it always creates a new // variable. -bool IsResourceHandleAnonymous(TF::VarHandleOp handle) { +bool IsResourceHandleAnonymous(VarHandleOp handle) { return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME; } // Returns a string unique identifier for a non-anonymous VarHandleOp. 
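
(Aside on the BacktrackValue loop above: it simply chases known pass-through hops — island yields, Identity/IdentityN, single-region wrappers, and calls whose callee forwards an argument — until no further hop is recorded. A minimal standalone sketch of that chase, with plain ints standing in for mlir::Value; PassThroughMap and these names are illustrative only, not part of this patch.)

#include <map>

// Maps a value to the value it merely forwards, when such a hop is known.
using PassThroughMap = std::map<int, int>;

// Follows pass-through hops until none is known. Assumes the forwarding
// relation is acyclic, as it is for SSA use-def chains.
int BacktrackValueSketch(int value, const PassThroughMap& forwards) {
  auto it = forwards.find(value);
  while (it != forwards.end()) {
    value = it->second;
    it = forwards.find(value);
  }
  return value;
}
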
-std::string GetVarHandleStringId(TF::VarHandleOp handle) { +std::string GetVarHandleStringId(VarHandleOp handle) { auto device = handle.getAttrOfType("device"); return absl::StrCat(handle.container().str(), "/", handle.shared_name().str(), "/", device ? device.getValue().str() : std::string("")); @@ -210,7 +255,7 @@ std::string GetVarHandleStringId(TF::VarHandleOp handle) { // Finds a unique ID for a VarHandleOp's output. If it is anonymous, always // creates a new ID; otherwise, tries to reuse the existing ID for the // referenced variable if it exists, or creates a new one if not. -int64_t GetOrCreateIdForVarHandle(TF::VarHandleOp handle, int64_t* next_id, +int64_t GetOrCreateIdForVarHandle(VarHandleOp handle, int64_t* next_id, llvm::StringMap* name_id_map) { // Always create a new ID for anonymous handle. if (IsResourceHandleAnonymous(handle)) return (*next_id)++; @@ -224,131 +269,269 @@ int64_t GetOrCreateIdForVarHandle(TF::VarHandleOp handle, int64_t* next_id, } // namespace -namespace detail { //===----------------------------------------------------------------------===// // ResourceAliasAnalysisInfo //===----------------------------------------------------------------------===// +constexpr int64_t ResourceAliasAnalysisInfo::kUnknownResourceId; + // Constructs the analysis info by analyzing the given function. ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo( - FuncOp func_op, const detail::BacktrackAnalysis& backtrack_analysis) { + FuncOp func_op, const BacktrackAnalysis& backtrack_analysis) { // This function populates resource_value_to_ids_ and id_to_resource_values_. + int64_t next_unique_id = 0; + + // Helper to assign new unique id for all resources in the given list of + // values. + auto assign_unique_id_to_all = [&](ValueRange values) { + for (Value value : filter_resources(values)) { + AddValueUniqueIDMapping(value, next_unique_id++); + } + }; + + // Helper to assign new unknown id for all resources in the given list of + // values. + auto assign_unknown_id_to_all = [&](ValueRange values) { + for (Value value : filter_resources(values)) { + AddValueUniqueIDMapping(value, kUnknownResourceId); + } + }; + // If the "tf.resource_arg_unique_id" argument attributes are present for // resource-type arguments, respect them when choosing IDs; otherwise, they // must not alias. - int64_t next_unique_id = 0; const bool has_arg_unique_id_attrs = llvm::any_of(func_op.getArguments(), [&](const BlockArgument& arg) { return func_op.getArgAttr(arg.getArgNumber(), kResourceArgUniqueIdAttr); }); // Maps the kResourceArgUniqueIdAttr attribute value to the internal integer // ID used by this pass. 
- llvm::SmallDenseMap attr_id_to_internal_id; - for (auto arg : func_op.getArguments()) { - if (!mlir::getElementTypeOrSelf(arg.getType()).isa()) - continue; - if (has_arg_unique_id_attrs) { + if (has_arg_unique_id_attrs) { + llvm::SmallDenseMap attr_id_to_internal_id; + for (auto arg : filter_resources(func_op.getArguments())) { auto id_attr = func_op.getArgAttrOfType( arg.getArgNumber(), kResourceArgUniqueIdAttr); assert(id_attr && - "tf.resource_arg_unique_id attribute should exist on either none " - "or all arguments."); + "tf.resource_arg_unique_id attribute should exist on either " + "none or all arguments."); auto emplace_res = attr_id_to_internal_id.try_emplace(id_attr.getInt(), next_unique_id++); AddValueUniqueIDMapping(arg, emplace_res.first->getSecond()); - } else { - AddValueUniqueIDMapping(arg, next_unique_id++); } + } else { + assign_unique_id_to_all(func_op.getArguments()); } - llvm::StringMap var_handle_name_id_map; - auto forward_input_to_output = [&](const Value& operand, - const Value& result) { - if (!mlir::getElementTypeOrSelf(result.getType()).isa()) - return; - auto& result_ids = resource_value_to_ids_[result]; - auto operand_it = resource_value_to_ids_.find(operand); - assert(operand_it != resource_value_to_ids_.end() && - "A resource-type output does not have the corresponding " - "resource-type input."); - result_ids.insert(operand_it->getSecond().begin(), - operand_it->getSecond().end()); - }; + // Since this analysis is neither inter-procedural nor inter-regional, + // each region attached to Op's within a function is analyzed independently. + // Seed this analysis for each such region by mapping all resource arguments + // for such regions to a new unique-id. This is required because walk() walks + // the attached regions first before visiting the op, so there is no + // opportunity during the walk to seed region arguments. Also note that walk + // eventually also visits the Op on which the walk() is called, so make sure + // we do not overwrite the function argument mapping here. 
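
To make the bookkeeping above concrete: AddValueUniqueIDMapping records the resource value to ID association in both directions, and two resources may alias exactly when their ID sets intersect or either set contains kUnknownResourceId. Below is a minimal standalone sketch of that invariant using plain STL maps, with ints standing in for mlir::Value; the AliasBookkeeping name is illustrative and not part of this patch.

#include <cstdint>
#include <map>
#include <set>

constexpr int64_t kUnknownResourceIdSketch = -1;

struct AliasBookkeeping {
  std::map<int, std::set<int64_t>> value_to_ids;  // resource value -> IDs
  std::map<int64_t, std::set<int>> id_to_values;  // ID -> resource values

  // Mirrors AddValueUniqueIDMapping: returns true if the mapping changed.
  bool AddMapping(int value, int64_t id) {
    value_to_ids[value].insert(id);
    return id_to_values[id].insert(value).second;
  }

  // Two resources may alias iff their ID sets intersect or either is unknown.
  bool MayAlias(int a, int b) const {
    const std::set<int64_t>& ids_a = value_to_ids.at(a);
    const std::set<int64_t>& ids_b = value_to_ids.at(b);
    if (ids_a.count(kUnknownResourceIdSketch) ||
        ids_b.count(kUnknownResourceIdSketch))
      return true;
    for (int64_t id : ids_a)
      if (ids_b.count(id)) return true;
    return false;
  }
};
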
func_op.walk([&](Operation* op) { - if (auto var_handle = llvm::dyn_cast(op)) { + if (op == func_op) return; + for (Region& region : op->getRegions()) { + assign_unique_id_to_all(region.getArguments()); + } + }); + + llvm::StringMap var_handle_name_id_map; + func_op.walk([&](Operation* op) { + if (auto var_handle = dyn_cast(op)) { AddValueUniqueIDMapping( var_handle.resource(), GetOrCreateIdForVarHandle(var_handle, &next_unique_id, &var_handle_name_id_map)); - } else if (llvm::isa(op)) { - for (auto operand_and_result : - llvm::zip(op->getOperands(), op->getResults())) { - forward_input_to_output(std::get<0>(operand_and_result), - std::get<1>(operand_and_result)); + } else if (llvm::isa(op)) { + for (auto result : filter_resources(op->getResults())) + PropagateInputToOutput(op->getOperand(result.getResultNumber()), + result); + } else if (auto while_op = dyn_cast(op)) { + AnalyzeWhileLoop(while_op, backtrack_analysis.GetAnalysisForFunc( + while_op.body_function())); + } else if (auto while_region = dyn_cast(op)) { + AnalyzeWhileLoop(while_region, backtrack_analysis.GetAnalysisForRegion( + while_region.body())); + } else if (auto case_op = dyn_cast(op)) { + llvm::SmallVector functions; + functions.reserve(case_op.branches().size()); + for (auto branch : case_op.branches()) + functions.emplace_back(SymbolTable::lookupNearestSymbolFrom( + case_op, branch.cast())); + + AnalyzeFunctionalCaseOrIfOp(case_op, functions, backtrack_analysis); + } else if (auto if_op = dyn_cast(op)) { + AnalyzeFunctionalCaseOrIfOp( + if_op, {if_op.then_function(), if_op.else_function()}, + backtrack_analysis); + } else if (llvm::isa(op)) { + AnalyzeRegionCaseOrIfOp(op, backtrack_analysis); + } else if (auto call = dyn_cast(op)) { + FuncOp func = dyn_cast(call.resolveCallable()); + if (!func) { + assign_unknown_id_to_all(op->getResults()); + return WalkResult::advance(); } - } else if (auto replicate = llvm::dyn_cast(op)) { - // The nested block for ReplicateOp is handled separately in side-effect - // analysis. Inside that block, we can still treat its block arguments as - // different resources. - for (auto arg : replicate.GetBody().getArguments()) { - if (mlir::getElementTypeOrSelf(arg.getType()).isa()) { - AddValueUniqueIDMapping(arg, next_unique_id++); - } - } - } else if (auto while_op = llvm::dyn_cast(op)) { - const auto& body_info = - backtrack_analysis.GetAnalysisForFunc(while_op.body_func()); - // If a result is a passthrough of the body input, use the corresponding - // operand's resource IDs. 
- for (auto result : llvm::enumerate(while_op.getResults())) { - if (!mlir::getElementTypeOrSelf(result.value().getType()) - .isa()) { - continue; - } - auto passthrough_arg = body_info.GetArg(result.index()); + const auto& func_info = backtrack_analysis.GetAnalysisForFunc(func); + for (auto result : filter_resources(op->getResults())) { + auto passthrough_arg = func_info.GetArg(result.getResultNumber()); if (passthrough_arg) { - forward_input_to_output( - while_op.getOperand(passthrough_arg.getValue()), result.value()); + PropagateInputToOutput( + call.getArgOperands()[passthrough_arg.getValue()], result); } else { - AddValueUniqueIDMapping(result.value(), kUnknownResourceId); + AddValueUniqueIDMapping(result, kUnknownResourceId); } } - } else if (auto if_op = llvm::dyn_cast(op)) { - const auto& then_info = - backtrack_analysis.GetAnalysisForFunc(if_op.then_func()); - const auto& else_info = - backtrack_analysis.GetAnalysisForFunc(if_op.else_func()); - // If a result is a passthrough of both branches' inputs, merge the - // resource IDs of corresponding operands for the two inputs. - for (auto result : llvm::enumerate(if_op.getResults())) { - if (!mlir::getElementTypeOrSelf(result.value().getType()) - .isa()) { - continue; - } - auto passthrough_then_arg = then_info.GetArg(result.index()); - auto passthrough_else_arg = else_info.GetArg(result.index()); - if (passthrough_then_arg && passthrough_else_arg) { - Value then_operand = if_op.input()[passthrough_then_arg.getValue()]; - Value else_operand = if_op.input()[passthrough_else_arg.getValue()]; - forward_input_to_output(then_operand, result.value()); - forward_input_to_output(else_operand, result.value()); - } else { - AddValueUniqueIDMapping(result.value(), kUnknownResourceId); - } + } else if (isa(op)) { + Region& region = op->getRegion(0); + const auto& body_info = backtrack_analysis.GetAnalysisForRegion(region); + for (auto result : filter_resources(op->getResults())) { + Value body_result = body_info.GetValue(result.getResultNumber()); + PropagateInputToOutput(body_result, result); } } else { - for (auto result : op->getResults()) { - if (!mlir::getElementTypeOrSelf(result.getType()) - .isa()) - continue; - AddValueUniqueIDMapping(result, kUnknownResourceId); - } + assign_unknown_id_to_all(op->getResults()); } + return WalkResult::advance(); }); } -bool ResourceAliasAnalysisInfo::IsUnknownResource(const Value resource) const { +// Propagates the resource ID's from an input operand to a result. Returns true +// if the mapping changed. +bool ResourceAliasAnalysisInfo::PropagateInputToOutput(const Value& operand, + const OpResult& result) { + auto operand_it = resource_value_to_ids_.find(operand); + assert(operand_it != resource_value_to_ids_.end() && + "A resource-type output does not have the corresponding " + "resource-type input."); + bool change = false; + for (int64_t id : operand_it->second) + change = AddValueUniqueIDMapping(result, id) || change; + return change; +} + +// Analyzes while loops to compute resourceIDs for the loop results. +// +// (1) The base case for the analysis is that if the loop body does not execute +// at all, the resource IDs for each result is the same as the resource IDs +// of the corresponding input. +// (2) If the loop does execute one or more times, then we need to account for +// data flow through the body of the while loop. If result #r is the same +// as arg #a of the loop body (pass through argument), then we can reason +// further, else if the result is not a passthrough, we mark it as unknown. 
+// (3) For passthrough results, if result #r is the same as arg #a of the loop +// body, after one iteration, result #r = arg #a, so we need to also +// propagate arg #a to result #r. After another iteration, arg #a of the +// loop body will be result #a of the previous iteration. So then we need +// propagate from result #a to result #r. Generalizing, the resource ID +// propagation (for results which are passthrough) looks like: +// +// for r in (0, num_results) : result[r] = arg[r]; +// repeat till no change { +// a = passthrough arg for result #r; +// result[r] += result[a]; +// } +// +void ResourceAliasAnalysisInfo::AnalyzeWhileLoop( + Operation* while_op, const BacktrackAnalysisInfo& body_info) { + // Seed the resource ID's for the results using either the resource ID of the + // passthrough arg, or unknown. We need to perform further analysis if we + // find a passthrough arg which is not the same as corresponding the result #. + llvm::SmallVector, 4> passthrough_args( + while_op->getNumResults()); + bool need_analysis = false; + for (auto result : filter_resources(while_op->getResults())) { + int result_index = result.getResultNumber(); + passthrough_args[result_index] = body_info.GetArg(result_index); + if (passthrough_args[result_index]) { + int passthru_index = passthrough_args[result_index].getValue(); + PropagateInputToOutput(while_op->getOperand(passthru_index), result); + need_analysis |= + !IsUnknownResource(result) && passthru_index != result_index; + } else { + AddValueUniqueIDMapping(result, kUnknownResourceId); + } + } + + if (!need_analysis) return; + + // We found a result that is not unknown and whose passthrough operand index + // is not the same as the result index, which means there is "crosstalk" + // between 2 or more operands. In that case, we do an iterative propagation + // of resource ID's till the results converge. + bool change = true; + while (change) { + change = false; + for (auto result : filter_resources(while_op->getResults())) { + if (IsUnknownResource(result)) continue; + // If this result has a valid passthrough arg, propagate resource ID's + // from the result of the passthrough arg + int result_index = result.getResultNumber(); + int passthru_index = passthrough_args[result_index].getValue(); + change = + PropagateInputToOutput(while_op->getResult(passthru_index), result) || + change; + } + } +} + +template +void ResourceAliasAnalysisInfo::AnalyzeFunctionalCaseOrIfOp( + CaseOrIfOp case_or_if_op, llvm::ArrayRef functions, + const BacktrackAnalysis& backtrack_analysis) { + llvm::SmallVector infos; + infos.reserve(functions.size()); + for (FuncOp func : functions) + infos.push_back(&backtrack_analysis.GetAnalysisForFunc(func)); + + // If a result is a passthrough of all branches' inputs, merge the resource + // IDs of corresponding operands for all the inputs. 
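
To make the two propagation rules concrete — the fixed-point iteration implemented in AnalyzeWhileLoop above, and the branch-merge rule just described (implemented below in AnalyzeFunctionalCaseOrIfOp) — here is a standalone sketch using plain STL sets of ints in place of the resource ID maps. A negative index stands for a missing passthrough argument, and all names are illustrative rather than part of this patch.

#include <set>
#include <vector>

using IdSet = std::set<int>;

// Fixed-point propagation for While: result r starts with the IDs of the
// operand feeding its passthrough body argument, then repeatedly absorbs the
// IDs of result #passthrough[r] until nothing changes. passthrough[r] < 0
// means "not a passthrough" (unknown) and is skipped here for brevity.
std::vector<IdSet> PropagateWhileResultIds(
    const std::vector<int>& passthrough,
    const std::vector<IdSet>& operand_ids) {
  const int n = static_cast<int>(passthrough.size());
  std::vector<IdSet> result_ids(n);
  for (int r = 0; r < n; ++r)
    if (passthrough[r] >= 0) result_ids[r] = operand_ids[passthrough[r]];
  bool changed = true;
  while (changed) {
    changed = false;
    for (int r = 0; r < n; ++r) {
      if (passthrough[r] < 0) continue;
      const IdSet src = result_ids[passthrough[r]];
      for (int id : src) changed |= result_ids[r].insert(id).second;
    }
  }
  return result_ids;
}

// Branch merge for functional Case/If: when every branch passes some operand
// through to result r, the result's IDs are the union over branches of that
// operand's IDs (the "any branch unknown" case is omitted here).
IdSet MergeBranchResultIds(const std::vector<int>& passthrough_per_branch,
                           const std::vector<IdSet>& operand_ids) {
  IdSet merged;
  for (int arg : passthrough_per_branch) {
    const IdSet& ids = operand_ids[arg];
    merged.insert(ids.begin(), ids.end());
  }
  return merged;
}
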
+ for (auto result : filter_resources(case_or_if_op.getResults())) { + llvm::SmallVector, 2> passthrough_args; + passthrough_args.reserve(functions.size()); + for (const auto* info : infos) + passthrough_args.emplace_back(info->GetArg(result.getResultNumber())); + + const bool all_passthrough_args_known = llvm::all_of( + passthrough_args, [](const llvm::Optional& passthrough_arg) { + return passthrough_arg.hasValue(); + }); + if (all_passthrough_args_known) { + for (const auto& passthrough_arg : passthrough_args) { + Value operand = case_or_if_op.input()[passthrough_arg.getValue()]; + PropagateInputToOutput(operand, result); + } + } else { + AddValueUniqueIDMapping(result, kUnknownResourceId); + } + } +} + +void ResourceAliasAnalysisInfo::AnalyzeRegionCaseOrIfOp( + Operation* case_or_if_op, const BacktrackAnalysis& backtrack_analysis) { + llvm::SmallVector infos; + infos.reserve(case_or_if_op->getNumRegions()); + for (Region& region : case_or_if_op->getRegions()) + infos.push_back(&backtrack_analysis.GetAnalysisForRegion(region)); + + // For region Case/If, the walk would have visited all branch regions before + // visiting the Case/If op. Backtracking of each region results will either + // give a value computed within these regions, or a region capture. If it is a + // region capture computed before this Case/If, it will have been visited + // earlier and a mapping would exist for that value. If it is computed within + // the region, then again a mapping would exist. + for (auto result : filter_resources(case_or_if_op->getResults())) { + for (const auto* info : infos) { + Value region_result = info->GetValue(result.getResultNumber()); + PropagateInputToOutput(region_result, result); + } + } +} + +bool ResourceAliasAnalysisInfo::IsUnknownResource(Value resource) const { auto it = resource_value_to_ids_.find(resource); assert(it != resource_value_to_ids_.end() && !it->getSecond().empty()); // The set is sorted so we only need to check the first element since @@ -360,6 +543,7 @@ bool ResourceAliasAnalysisInfo::IsUnknownResource(const Value resource) const { const llvm::SmallSet& ResourceAliasAnalysisInfo::GetResourceUniqueIds(Value resource) const { + assert(!IsUnknownResource(resource)); auto it = resource_value_to_ids_.find(resource); assert(it != resource_value_to_ids_.end() && "Unseen resource was queried"); return it->getSecond(); @@ -373,14 +557,19 @@ ResourceAliasAnalysisInfo::GetUniqueIdResources(const int64_t id) const { } llvm::SmallSetVector ResourceAliasAnalysisInfo::GetResourceAliases( - const Value resource) const { - assert(!IsUnknownResource(resource) && "Unseen resource was queried"); + Value resource) const { + assert(!IsUnknownResource(resource) && "Unknown resource was queried"); llvm::SmallSetVector aliases; for (int64_t id : GetResourceUniqueIds(resource)) { const llvm::SmallSetVector& resources_aliasing_id = GetUniqueIdResources(id); aliases.insert(resources_aliasing_id.begin(), resources_aliasing_id.end()); } + // If there are resources that were marked as unknown, they alias with all + // other resources. 
+ auto it = id_to_resource_values_.find(kUnknownResourceId); + if (it != id_to_resource_values_.end()) + aliases.insert(it->getSecond().begin(), it->getSecond().end()); return aliases; } @@ -390,10 +579,7 @@ llvm::SmallSetVector ResourceAliasAnalysisInfo::GetResourceAliases( // ResourceAliasAnalysis //===----------------------------------------------------------------------===// -ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) { - auto module = dyn_cast(op); - assert(module); - +ResourceAliasAnalysis::ResourceAliasAnalysis(ModuleOp module) { // Analyze all regions for backtracking info. detail::BacktrackAnalysis backtrack_analysis(module); diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h index 5a514a7fb64..5575767dcc4 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h @@ -20,18 +20,23 @@ limitations under the License. #include #include +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/per_function_aggregate_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { namespace TF { namespace detail { class BacktrackAnalysis; +class BacktrackAnalysisInfo; // Resource alias analysis information for a single function. class ResourceAliasAnalysisInfo { @@ -43,7 +48,7 @@ class ResourceAliasAnalysisInfo { ResourceAliasAnalysisInfo(ResourceAliasAnalysisInfo&&) = default; // Returns if the analysis fails to resolve a resource-type value. - bool IsUnknownResource(const Value resource) const; + bool IsUnknownResource(Value resource) const; // Returns the set unique IDs which `resource` could alias. Requires that // IsUnknownResource(resource) == false. @@ -54,15 +59,35 @@ class ResourceAliasAnalysisInfo { llvm::SmallSetVector GetResourceAliases(Value resource) const; private: - // Maps resource value to unique ID and vice-versa. - void AddValueUniqueIDMapping(Value value, int64_t id) { + // Maps resource value to unique ID and vice-versa. Returns true of the + // mapping has changed. + bool AddValueUniqueIDMapping(Value value, int64_t id) { resource_value_to_ids_[value].insert(id); - id_to_resource_values_[id].insert(value); + return id_to_resource_values_[id].insert(value); } // Returns the set unique Values which map to `id`. const llvm::SmallSetVector& GetUniqueIdResources(int64_t id) const; + // Propagates the resource ID's from an input operand to a result. Returns + // true of the mapping has changed. + bool PropagateInputToOutput(const Value& operand, const OpResult& result); + + // Analyzes while loops to compute resourceID's for the loop results. + // `body_info` is the backtrack analysis info for the loop body. + void AnalyzeWhileLoop(Operation* while_op, + const BacktrackAnalysisInfo& body_info); + + // Analyzes tf.Case/tf.If ops to compute resourceID's. + template + void AnalyzeFunctionalCaseOrIfOp(CaseOrIfOp case_or_if_op, + llvm::ArrayRef functions, + const BacktrackAnalysis& backtrack_analysis); + + // Analyzes tf.CaseRegion/tf.IfRegion ops to compute resourceID's. 
+ void AnalyzeRegionCaseOrIfOp(Operation* case_or_if_op, + const BacktrackAnalysis& backtrack_analysis); + // Maps each resource-type value to a set of unique IDs that it could alias. llvm::SmallDenseMap, 8> resource_value_to_ids_; @@ -88,9 +113,18 @@ class ResourceAliasAnalysis : public detail::PerFunctionAggregateAnalysis< detail::ResourceAliasAnalysisInfo> { public: // Constructs analysis by analyzing the given module operation. - explicit ResourceAliasAnalysis(Operation* op); + explicit ResourceAliasAnalysis(ModuleOp module); }; +// Returns a range with just resource type values from the input range +// preserved. +template +auto filter_resources(RangeT&& range) { + return llvm::make_filter_range(std::forward(range), [](Value val) { + return getElementTypeOrSelf(val.getType()).isa(); + }); +} + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index 9e78b90debc..4a2080c5951 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -21,27 +21,29 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/tf2xla/resource_operation_table.h" -#include "tensorflow/core/framework/resource_mgr.h" namespace mlir { namespace TF { @@ -67,16 +69,12 @@ llvm::SmallDenseSet FindAccessedResources( Operation* op, const ResourceAliasAnalysis::Info& alias_analysis) { llvm::SmallDenseSet resources; - for (auto operand : op->getOperands()) { - if (!mlir::getElementTypeOrSelf(operand.getType()).isa()) - continue; + for (auto operand : filter_resources(op->getOperands())) { if (alias_analysis.IsUnknownResource(operand)) return UnknownResourceSet(); const auto& ids = alias_analysis.GetResourceUniqueIds(operand); resources.insert(ids.begin(), ids.end()); } - for (auto result : op->getResults()) { - if (!mlir::getElementTypeOrSelf(result.getType()).isa()) - continue; + for (auto result : filter_resources(op->getResults())) { if (alias_analysis.IsUnknownResource(result)) return UnknownResourceSet(); const auto& ids = alias_analysis.GetResourceUniqueIds(result); resources.insert(ids.begin(), ids.end()); @@ -84,67 +82,139 @@ llvm::SmallDenseSet 
FindAccessedResources( return resources; } -// Returns an XlaResourceOpInfo (or nullptr if it does not exist) that specifies -// the resource access type of the op. It tells whether the op is read only, -// etc. -// -// TODO(yuanzx): Define this information in a different place. Currently we use -// tensorflow/compiler/tf2xla/resource_operation_table.h. -const tensorflow::XlaResourceOpInfo* GetResourceInfoForOp(Operation* op) { - if (op->getName().getDialect() != - TF::TensorFlowDialect::getDialectNamespace()) { - return nullptr; +// Helper struct defining what memory effects are present for a resource. +struct SideEffects { + bool alloc = false; + bool free = false; + bool read = false; + bool write = false; + + bool IsAllocOnly() const { return alloc && !free && !read && !write; } + bool IsReadOnly() const { return !alloc && !free && read && !write; } +}; + +using ResourceSideEffectsByValue = llvm::SmallDenseMap; + +// Collects memory side effects for an operation by value (operands and +// results). +ResourceSideEffectsByValue GetResourceInfoForOp(Operation* op) { + ResourceSideEffectsByValue resource_info; + + auto interface = dyn_cast(op); + if (!interface) return resource_info; + + llvm::SmallVector effects; + interface.getEffects(effects); + + for (auto& effect : effects) { + // TODO(lyandy): Support effects with no value defined. + if (!effect.getValue()) return ResourceSideEffectsByValue(); + auto it = resource_info.try_emplace(effect.getValue()); + auto& side_effect = it.first->getSecond(); + auto* resource_effect = effect.getEffect(); + if (isa(resource_effect)) { + side_effect.alloc = true; + } else if (isa(resource_effect)) { + side_effect.free = true; + } else if (isa(resource_effect)) { + side_effect.read = true; + } else if (isa(resource_effect)) { + side_effect.write = true; + } else { + return ResourceSideEffectsByValue(); + } } - return tensorflow::GetResourceOpInfoForOp( - op->getName().getStringRef().split('.').second.str()); + + return resource_info; } -// Returns whether `op` accesses resources and it is known to be read-only. -bool OpIsReadOnly(Operation* op) { - auto resource_op_info = GetResourceInfoForOp(op); - return resource_op_info && - resource_op_info->kind() == tensorflow::XlaResourceOpKind::kRead; +// Checks if a value is a result of `op`. +bool IsOperationResult(Operation* op, Value value) { + return value.getDefiningOp() == op; +} + +// Checks if an operation's resource operands are read only. Operation results +// are ignored. +bool IsResourceOpReadOnly(Operation* op, + const ResourceSideEffectsByValue& resource_op_info) { + if (resource_op_info.empty()) return false; + + for (const auto& resource_info : resource_op_info) { + Value value = resource_info.getFirst(); + if (IsOperationResult(op, value)) continue; + const SideEffects& side_effects = resource_info.getSecond(); + if (!side_effects.IsReadOnly()) return false; + } + + return true; +} + +// Checks if an operation's resource results are alloc only and no side effects +// are present for its operands. +bool IsResourceOpAllocOnly(Operation* op, + const ResourceSideEffectsByValue& resource_op_info) { + if (resource_op_info.empty()) return false; + + for (const auto& resource_info : resource_op_info) { + // Operand with side effect. 
+ Value value = resource_info.getFirst(); + if (!IsOperationResult(op, value)) return false; + const SideEffects& side_effects = resource_info.getSecond(); + if (!side_effects.IsAllocOnly()) return false; + } + + return true; } // Returns if `op` is a resource declaration. bool OpIsDeclaration(Operation* op, const ResourceAliasAnalysis::Info& alias_analysis) { - // TODO(yuanzx): Add other types of resources. - return llvm::isa(op) || - (llvm::isa(op) && - !FindAccessedResources(op, alias_analysis).empty()); + return llvm::isa(op) && + !FindAccessedResources(op, alias_analysis).empty(); } -// Returns if `op` is know to not have any side effect. -bool OpIsKnownToHaveNoSideEffect(Operation* op) { - // TODO(riverriddle) We shouldn't treat all terminator operations as having - // side effects, this should be relaxed. - // TODO(riverriddle) Properly handle region side effects. - if (MemoryEffectOpInterface::hasNoEffect(op) && op->isKnownNonTerminator() && - op->getNumRegions() == 0) { - return true; - } - if (auto if_op = llvm::dyn_cast(op)) { - return if_op.is_stateless(); - } - if (auto while_op = llvm::dyn_cast(op)) { - return while_op.is_stateless(); - } +// A vector of resource variable id's with their associated resource value. +using ResourceIdsByValue = + llvm::SmallVector*>, 4>; - // Try to get the statefulness flag from the registry. - // - // TODO(yuanzx): Remove this after all ops are defined in the dialect. - if (op->getName().getDialect() != - TF::TensorFlowDialect::getDialectNamespace()) { - return false; - } - StringRef op_name = op->getName().getStringRef(); - // Drop the `tf.` prefix to query TF registry. - auto node_name = - op_name.drop_front(TensorFlowDialect::getDialectNamespace().size() + 1); - const tensorflow::OpRegistrationData* op_reg_data = - tensorflow::OpRegistry::Global()->LookUp(node_name.data()); - return op_reg_data && !op_reg_data->op_def.is_stateful(); +// Collects resource id's by resource value. If operation resource side effects +// are unknown or a resource is unknown, an empty optional is returned. +llvm::Optional GetResourceIdsByValue( + Operation* op, const ResourceAliasAnalysis::Info& alias_analysis, + const ResourceSideEffectsByValue& resource_op_info) { + ResourceIdsByValue resource_ids_by_value; + if (resource_op_info.empty()) return llvm::None; + + auto collect_ids = [&](ValueRange values) { + for (auto value : filter_resources(values)) { + if (alias_analysis.IsUnknownResource(value)) return false; + const auto& ids = alias_analysis.GetResourceUniqueIds(value); + resource_ids_by_value.push_back({value, &ids}); + } + return true; + }; + + if (collect_ids(op->getOperands()) && collect_ids(op->getResults())) + return resource_ids_by_value; + else + return llvm::None; +} + +// Returns true if `op` is known to not have any side effect. +bool OpIsKnownToHaveNoSideEffect(Operation* op) { + // Note: Identity op is really side-effect free, but it is not marked as such + // in the TF dialect (see comments in definition of Identity op in tf_ops.td) + // However, for adding control dependencies, its safe to assume + // that the Identity op is side-effect free. + if (isa(op)) return true; + + // For op's in the Tensorflow dialect, query the dialect. + if (op->getName().getDialect() == + TF::TensorFlowDialect::getDialectNamespace()) + return !TensorFlowDialect::CanHaveSideEffects(op); + + // Otherwise, conservatively assume that there can be side effects. 
+ return false; } } // namespace @@ -272,17 +342,17 @@ void SideEffectAnalysisInfo::AnalyzeRegion( if (OpIsDeclaration(&op, alias_analysis)) continue; auto resource_op_info = GetResourceInfoForOp(&op); - if (!resource_op_info && OpIsKnownToHaveNoSideEffect(&op)) continue; + if (resource_op_info.empty() && OpIsKnownToHaveNoSideEffect(&op)) + continue; - llvm::SmallDenseSet resources = - resource_op_info ? FindAccessedResources(&op, alias_analysis) - : UnknownResourceSet(); - assert(!resources.empty()); - const bool is_unknown = resources.count(kUnknownResourceId) > 0; - const bool read_only = OpIsReadOnly(&op); + if (IsResourceOpAllocOnly(&op, resource_op_info)) continue; + + auto resource_ids_by_value = + GetResourceIdsByValue(&op, alias_analysis, resource_op_info); + const bool read_only = IsResourceOpReadOnly(&op, resource_op_info); bool indirectly_tracked_unknown_access = false; // First add edges from known resources. - if (is_unknown) { + if (!resource_ids_by_value.hasValue()) { for (auto& entry : per_resource_access_info_) { if (entry.getFirst() == kUnknownResourceId) continue; AddPredecessorsForAccess(entry.getFirst(), &op, read_only); @@ -291,20 +361,43 @@ void SideEffectAnalysisInfo::AnalyzeRegion( read_only); } } else { - for (int64_t resource : resources) { - AddPredecessorsForAccess(resource, &op, read_only); + // Collect all resource id's and whether their side effect is read only. + llvm::SmallDenseMap read_only_by_resource_id; + for (const auto& resource_ids : *resource_ids_by_value) { + const bool is_result = resource_ids.first.getDefiningOp() == &op; + auto value_resource_info = resource_op_info.find(resource_ids.first); + bool resource_read_only = false; + if (value_resource_info != resource_op_info.end()) { + if (is_result && value_resource_info->getSecond().IsAllocOnly()) + continue; + resource_read_only = value_resource_info->getSecond().IsReadOnly(); + } + + for (const auto& id : *resource_ids.second) { + auto it = + read_only_by_resource_id.try_emplace(id, resource_read_only); + if (!it.second && !resource_read_only) + it.first->getSecond() = resource_read_only; + } + } + + for (const auto& resource : read_only_by_resource_id) { + const auto& resource_id = resource.getFirst(); + const auto& resource_read_only = resource.getSecond(); + AddPredecessorsForAccess(resource_id, &op, resource_read_only); indirectly_tracked_unknown_access |= - unknown_access_indirectly_tracked_by_resource(resource, - read_only); + unknown_access_indirectly_tracked_by_resource(resource_id, + resource_read_only); // Update access info for known resources. - TrackAccess(resource, &op, read_only); + TrackAccess(resource_id, &op, resource_read_only); } } + // If not indirectly tracked, add edges from the unknown resource. if (!indirectly_tracked_unknown_access) { AddPredecessorsForAccess(kUnknownResourceId, &op, read_only); } - if (is_unknown) { + if (!resource_ids_by_value.hasValue()) { // Update access info for unknown resource. TrackAccess(kUnknownResourceId, &op, read_only); } @@ -339,10 +432,7 @@ SideEffectAnalysisInfo::DirectControlSuccessors( } } // namespace detail -SideEffectAnalysis::SideEffectAnalysis(Operation* op) { - auto module = dyn_cast(op); - assert(module); - +SideEffectAnalysis::SideEffectAnalysis(ModuleOp module) { // Analyze entire module for alias analysis info. 
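
Roughly, the per-resource ordering rule that AnalyzeRegion above enforces through AddPredecessorsForAccess and TrackAccess is: a read-only access must be ordered after the last write to the resource, while a write must be ordered after the last write and after every read since that write. Below is a standalone sketch of just that rule, with plain ints standing in for op handles and resource IDs; the names are illustrative and not taken from this patch.

#include <map>
#include <vector>

struct AccessInfo {
  int last_write = -1;           // index of the last write, -1 if none yet
  std::vector<int> reads_since;  // reads observed since that write
};

// Returns the op indices the new access must be ordered after, and records
// the access for subsequent queries.
std::vector<int> PredecessorsForAccess(std::map<int, AccessInfo>& per_resource,
                                       int resource_id, int op_index,
                                       bool read_only) {
  AccessInfo& info = per_resource[resource_id];
  std::vector<int> preds;
  if (info.last_write >= 0) preds.push_back(info.last_write);
  if (read_only) {
    info.reads_since.push_back(op_index);
  } else {
    preds.insert(preds.end(), info.reads_since.begin(),
                 info.reads_since.end());
    info.last_write = op_index;
    info.reads_since.clear();
  }
  return preds;
}
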
ResourceAliasAnalysis alias_analysis(module); diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h index c92c6e1882c..a75f7eb7dee 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h @@ -130,7 +130,7 @@ class SideEffectAnalysis : public detail::PerFunctionAggregateAnalysis< detail::SideEffectAnalysisInfo> { public: // Constructs analysis by analyzing the given module operation. - explicit SideEffectAnalysis(Operation* op); + explicit SideEffectAnalysis(ModuleOp module); }; } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/c/BUILD b/tensorflow/compiler/mlir/tensorflow/c/BUILD index 801e35280d6..5c6f39699bf 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/c/BUILD @@ -2,7 +2,6 @@ load( "//tensorflow:tensorflow.bzl", "tf_copts", "tf_cuda_library", - "tfe_xla_copts", ) package( @@ -20,7 +19,7 @@ tf_cuda_library( srcs = [ "c_api_unified_experimental_mlir.cc", ], - copts = tf_copts() + tfe_xla_copts(), + copts = tf_copts(), deps = [ "//tensorflow/c:c_api", "//tensorflow/c:tensor_interface", @@ -41,6 +40,7 @@ tf_cuda_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/platform:errors", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index 66447995709..6bfe4c302cd 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/raw_ostream.h" @@ -42,6 +43,7 @@ limitations under the License. 
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -64,21 +66,18 @@ using tensorflow::AbstractTensorInterface; using tensorflow::dyn_cast; using tensorflow::OutputList; using tensorflow::string; +using tensorflow::errors::FailedPrecondition; +using tensorflow::errors::InvalidArgument; +using tensorflow::errors::Unimplemented; using tensorflow::tracing::TracingContext; using tensorflow::tracing::TracingOperation; using tensorflow::tracing::TracingTensorHandle; namespace { -static void RegisterDialects() { - static bool init_once = []() { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - return true; - }(); - (void)init_once; +void RegisterDialects(mlir::MLIRContext& ctx) { + mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry()); + ctx.getDialectRegistry().loadAll(&ctx); } Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder, @@ -95,7 +94,7 @@ class MlirTensor : public TracingTensorHandle { tensorflow::DataType DataType() const override { tensorflow::DataType type; - Status s = ConvertScalarTypeToDataType(value_.getType(), &type); + Status s = ConvertToDataType(value_.getType(), &type); if (!s.ok()) { return tensorflow::DT_INVALID; } @@ -103,6 +102,9 @@ class MlirTensor : public TracingTensorHandle { } Value getValue() { return value_; } + Type getElementType() { + return value_.getType().cast().getElementType(); + } // For LLVM style RTTI. static bool classof(const AbstractTensorHandle* ptr) { @@ -184,11 +186,18 @@ class MlirAbstractOp : public TracingOperation { } private: + // Return true is there are still unfilled ODS slots for adding more inputs. + bool IsNextODSArgAvailable(); + MLIRContext* context_; MlirFunctionContext* function_context_; SmallVector operands_; llvm::StringMap attrs_; std::unique_ptr state_; + // This is the index of the next ODS operand that will be added with AddInput + // or AddInput; + int current_ods_input_ = 0; + const tensorflow::OpDef* op_def_ = nullptr; const char* op_name_ = nullptr; string tf_op_type_; // TODO(srbs): Use this. 
@@ -225,6 +234,7 @@ class MlirFunctionContext : public TracingContext { : TracingContext(kMlir), context_(std::make_unique()), builder_(context_.get()) { + RegisterDialects(*context_); // TODO(aminim) figure out the location story here module_ = ModuleOp::create(builder_.getUnknownLoc()); func_ = FuncOp::create(builder_.getUnknownLoc(), name, @@ -244,12 +254,12 @@ class MlirFunctionContext : public TracingContext { Status Finalize(OutputList* outputs, AbstractFunction** f) override; Status RegisterFunction(AbstractFunction* func) override { - return tensorflow::errors::Unimplemented( + return Unimplemented( "Registering graph functions has not been implemented yet."); } Status RemoveFunction(const string& func) override { - return tensorflow::errors::Unimplemented( + return Unimplemented( "MlirFunctionContext::RemoveFunction has not been implemented yet."); } @@ -264,9 +274,12 @@ class MlirFunctionContext : public TracingContext { Status MlirAbstractOp::Reset(const char* op, const char* device_name) { if (state_) { - return tensorflow::errors::FailedPrecondition( - "Reset called on already built op."); + return FailedPrecondition("Reset called on already built op."); } + TF_RETURN_IF_ERROR( + tensorflow::OpRegistry::Global()->LookUpOpDef(op, &op_def_)); + assert(op_def_); + tf_op_type_ = op; std::string name = "tf."; name += op; @@ -277,13 +290,12 @@ Status MlirAbstractOp::Reset(const char* op, const char* device_name) { Status MlirAbstractOp::SetAttrType(const char* attr_name, tensorflow::DataType dtype) { - if (!state_) { - return Status(tensorflow::error::Code::FAILED_PRECONDITION, - "op_type must be specified before specifying attrs."); - } + if (!state_) + return FailedPrecondition( + "op_type must be specified before specifying attrs."); Type mlir_type; Builder builder(context_); - TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder, &mlir_type)); + TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &mlir_type)); attrs_[attr_name] = TypeAttr::get(mlir_type); return Status::OK(); } @@ -291,8 +303,7 @@ Status MlirAbstractOp::SetAttrType(const char* attr_name, Status MlirAbstractOp::SetOpName(const char* const op_name) { // TODO(aminim): should we use a location? if (op_name_) { - return tensorflow::errors::FailedPrecondition( - "SetOpName called on already built op."); + return FailedPrecondition("SetOpName called on already built op."); } op_name_ = op_name; return Status::OK(); @@ -301,8 +312,7 @@ Status MlirAbstractOp::SetOpName(const char* const op_name) { Status MlirAbstractOp::AddRef(Type type, Type* output_type) { Type elt_type = getElementTypeOrSelf(type); if (elt_type.isa()) { - return tensorflow::errors::InvalidArgument( - "Requested reference to a reference type"); + return InvalidArgument("Requested reference to a reference type"); } elt_type = TensorFlowRefType::get(elt_type); if (RankedTensorType tensor_type = type.dyn_cast()) { @@ -315,138 +325,97 @@ Status MlirAbstractOp::AddRef(Type type, Type* output_type) { Status MlirAbstractOp::Create(ArrayRef operands, OperationState** state) { state_->operands = llvm::to_vector<4>(operands); - const tensorflow::OpDef* op_def; - auto node_name = state_->name.getStringRef().drop_front( - TensorFlowDialect::getDialectNamespace().size() + 1); - TF_RETURN_IF_ERROR( - tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def)); Builder builder(context_); - // Process operands according to the op_def and infer derived attributes. 
- int current_operand = 0; - for (const tensorflow::OpDef::ArgDef& input_arg : op_def->input_arg()) { - if (!input_arg.number_attr().empty()) { - // TODO(b/156122856): we don't support variadic operands. - return tensorflow::errors::Unimplemented( - "Unsupported 'number_attr' for '", input_arg.number_attr(), "'"); - } else if (!input_arg.type_list_attr().empty()) { - return tensorflow::errors::InvalidArgument( - "Unsupported 'type_list_attr' for '", input_arg.number_attr(), "'"); - } - if (current_operand >= operands.size()) { - return tensorflow::errors::InvalidArgument("Missing operand for '", - input_arg.name(), "'"); - } - Type expected_type; - if (input_arg.type() != tensorflow::DT_INVALID) { - TF_RETURN_IF_ERROR( - ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type)); - Type output_type; - if (input_arg.is_ref()) - TF_RETURN_IF_ERROR(AddRef(expected_type, &output_type)); - expected_type = output_type; - } else { - expected_type = operands[current_operand].getType(); - } - if (!input_arg.type_attr().empty()) { - attrs_[input_arg.type_attr()] = TypeAttr::get(expected_type); - } - ++current_operand; - } - for (const tensorflow::OpDef::ArgDef& output_arg : op_def->output_arg()) { + if (current_ods_input_ != op_def_->input_arg_size()) + return InvalidArgument(absl::StrCat("Mismatch in operands number: got ", + current_ods_input_, " expected ", + op_def_->input_arg_size(), " ; for op ", + state_->name.getStringRef().str())); + + // Process results according to the op_def and infer types for derived + // attributes. + for (const tensorflow::OpDef::ArgDef& output_arg : op_def_->output_arg()) { int original_size = state_->types.size(); if (!output_arg.number_attr().empty()) { // Same type repeated "repeats" times. Attribute repeats_attr = attrs_[output_arg.number_attr()]; - if (!repeats_attr) { - return tensorflow::errors::InvalidArgument( - "Missing attribute '", output_arg.number_attr(), - "' required for output list '", output_arg.name(), "'"); - } - if (!repeats_attr.isa()) { - return tensorflow::errors::InvalidArgument( - "Attribute '", output_arg.number_attr(), - "' required for output list '", output_arg.name(), - "' isn't an integer"); - } + if (!repeats_attr) + return InvalidArgument("Missing attribute '", output_arg.number_attr(), + "' required for output list '", + output_arg.name(), "'"); + if (!repeats_attr.isa()) + return InvalidArgument("Attribute '", output_arg.number_attr(), + "' required for output list '", + output_arg.name(), "' isn't an integer"); int64_t repeats = repeats_attr.cast().getInt(); if (!output_arg.type_attr().empty()) { // Same type repeated "repeats" times. 
Attribute attr = attrs_[output_arg.type_attr()]; - if (!attr) { - return tensorflow::errors::InvalidArgument( - "Missing attribute '", output_arg.type_attr(), - "' required for output '", output_arg.name(), "'"); - } + if (!attr) + return InvalidArgument("Missing attribute '", output_arg.type_attr(), + "' required for output '", output_arg.name(), + "'"); TypeAttr type_attr = attr.dyn_cast(); - if (!type_attr) { - return tensorflow::errors::InvalidArgument( - "Attribute '", output_arg.type_attr(), "' required for output '", - output_arg.name(), "' isn't a type attribute"); - } + if (!type_attr) + return InvalidArgument("Attribute '", output_arg.type_attr(), + "' required for output '", output_arg.name(), + "' isn't a type attribute"); for (int i = 0; i < repeats; ++i) - state_->types.push_back(type_attr.getType()); + state_->types.push_back(UnrankedTensorType::get(type_attr.getType())); } else if (output_arg.type() != tensorflow::DT_INVALID) { for (int i = 0; i < repeats; ++i) { Type type; TF_RETURN_IF_ERROR( - ConvertDataTypeToTensor(output_arg.type(), builder, &type)); + ConvertDataType(output_arg.type(), builder, &type)); state_->types.push_back(type); } } else { - return tensorflow::errors::InvalidArgument( - "Missing type or type_attr field in ", - output_arg.ShortDebugString()); + return InvalidArgument("Missing type or type_attr field in ", + output_arg.ShortDebugString()); } } else if (!output_arg.type_attr().empty()) { Attribute attr = attrs_[output_arg.type_attr()]; - if (!attr) { - return tensorflow::errors::InvalidArgument( - "Missing attribute '", output_arg.type_attr(), - "' required for output '", output_arg.name(), "'"); - } + if (!attr) + return InvalidArgument("Missing attribute '", output_arg.type_attr(), + "' required for output '", output_arg.name(), + "'"); TypeAttr type_attr = attr.dyn_cast(); - if (!type_attr) { - return tensorflow::errors::InvalidArgument( - "Attribute '", output_arg.type_attr(), "' required for output '", - output_arg.name(), "' isn't a type attribute"); - } - state_->types.push_back(type_attr.getValue()); + if (!type_attr) + return InvalidArgument("Attribute '", output_arg.type_attr(), + "' required for output '", output_arg.name(), + "' isn't a type attribute"); + state_->types.push_back(UnrankedTensorType::get(type_attr.getValue())); } else if (!output_arg.type_list_attr().empty()) { // This is pointing to an attribute which is an array of types. 
Attribute attr = attrs_[output_arg.type_list_attr()]; - if (!attr) { - return tensorflow::errors::InvalidArgument( + if (!attr) + return InvalidArgument( "Missing attribute '", output_arg.type_list_attr(), "' required for output '", output_arg.name(), "'"); - } ArrayAttr array_attr = attr.dyn_cast(); - if (!array_attr) { - return tensorflow::errors::InvalidArgument( - "Attribute '", output_arg.type_list_attr(), - "' required for output '", output_arg.name(), - "' isn't an array attribute"); - } + if (!array_attr) + return InvalidArgument("Attribute '", output_arg.type_list_attr(), + "' required for output '", output_arg.name(), + "' isn't an array attribute"); for (Attribute attr : array_attr) { TypeAttr type_attr = attr.dyn_cast(); - if (!type_attr) { - return tensorflow::errors::InvalidArgument( - "Array Attribute '", output_arg.type_list_attr(), - "' required for output '", output_arg.name(), - "' has a non-Type element"); - } - state_->types.push_back(type_attr.getValue()); + if (!type_attr) + return InvalidArgument("Array Attribute '", + output_arg.type_list_attr(), + "' required for output '", output_arg.name(), + "' has a non-Type element"); + state_->types.push_back(UnrankedTensorType::get(type_attr.getValue())); } } else if (output_arg.type() != tensorflow::DT_INVALID) { Type type; Builder builder(context_); - TF_RETURN_IF_ERROR( - ConvertDataTypeToTensor(output_arg.type(), builder, &type)); + TF_RETURN_IF_ERROR(ConvertDataType(output_arg.type(), builder, &type)); state_->types.push_back(type); } else { - return tensorflow::errors::InvalidArgument("No type fields in ", - output_arg.ShortDebugString()); + return InvalidArgument("No type fields in ", + output_arg.ShortDebugString()); } if (output_arg.is_ref()) { // For all types that were added by this function call, make them refs. 
@@ -458,6 +427,7 @@ Status MlirAbstractOp::Create(ArrayRef operands, } } } + for (auto& it : attrs_) state_->addAttribute(it.first(), it.second); *state = state_.get(); return Status::OK(); } @@ -471,88 +441,68 @@ Status MlirAbstractOp::SetDeviceName(const char* name) { return Status::OK(); } -Status MlirAbstractOp::AddInputList( - absl::Span inputs) { - return tensorflow::errors::Unimplemented( - "AddInputList has not been implemented yet."); -} - Status MlirAbstractOp::SetAttrString(const char* attr_name, const char* data, size_t length) { - return tensorflow::errors::Unimplemented( - "SetAttrString has not been implemented yet."); + return Unimplemented("SetAttrString has not been implemented yet."); } Status MlirAbstractOp::SetAttrInt(const char* attr_name, int64_t value) { - return tensorflow::errors::Unimplemented( - "SetAttrInt has not been implemented yet."); + return Unimplemented("SetAttrInt has not been implemented yet."); } Status MlirAbstractOp::SetAttrFloat(const char* attr_name, float value) { - return tensorflow::errors::Unimplemented( - "SetAttrFloat has not been implemented yet."); + return Unimplemented("SetAttrFloat has not been implemented yet."); } Status MlirAbstractOp::SetAttrBool(const char* attr_name, bool value) { - return tensorflow::errors::Unimplemented( - "SetAttrBool has not been implemented yet."); + attrs_[attr_name] = BoolAttr::get(value, context_); + return Status::OK(); } Status MlirAbstractOp::SetAttrShape(const char* attr_name, const int64_t* dims, const int num_dims) { - return tensorflow::errors::Unimplemented( - "SetAttrShape has not been implemented yet."); + return Unimplemented("SetAttrShape has not been implemented yet."); } Status MlirAbstractOp::SetAttrFunction(const char* attr_name, const AbstractOperation* value) { - return tensorflow::errors::Unimplemented( - "SetAttrFunction has not been implemented yet."); + return Unimplemented("SetAttrFunction has not been implemented yet."); } Status MlirAbstractOp::SetAttrFunctionName(const char* attr_name, const char* value, size_t length) { - return tensorflow::errors::Unimplemented( - "SetAttrFunctionName has not been implemented yet."); + return Unimplemented("SetAttrFunctionName has not been implemented yet."); } Status MlirAbstractOp::SetAttrTensor(const char* attr_name, AbstractTensorInterface* tensor) { - return tensorflow::errors::Unimplemented( - "SetAttrTensor has not been implemented yet."); + return Unimplemented("SetAttrTensor has not been implemented yet."); } Status MlirAbstractOp::SetAttrStringList(const char* attr_name, const void* const* values, const size_t* lengths, int num_values) { - return tensorflow::errors::Unimplemented( - "SetAttrStringList has not been implemented yet."); + return Unimplemented("SetAttrStringList has not been implemented yet."); } Status MlirAbstractOp::SetAttrFloatList(const char* attr_name, const float* values, int num_values) { - return tensorflow::errors::Unimplemented( - "SetAttrFloatList has not been implemented yet."); + return Unimplemented("SetAttrFloatList has not been implemented yet."); } Status MlirAbstractOp::SetAttrIntList(const char* attr_name, const int64_t* values, int num_values) { - return tensorflow::errors::Unimplemented( - "SetAttrIntList has not been implemented yet."); + return Unimplemented("SetAttrIntList has not been implemented yet."); } Status MlirAbstractOp::SetAttrTypeList(const char* attr_name, const tensorflow::DataType* values, int num_values) { - return tensorflow::errors::Unimplemented( - "SetAttrTypeList has not been 
implemented yet."); + return Unimplemented("SetAttrTypeList has not been implemented yet."); } Status MlirAbstractOp::SetAttrBoolList(const char* attr_name, const unsigned char* values, int num_values) { - return tensorflow::errors::Unimplemented( - "SetAttrBoolList has not been implemented yet."); + return Unimplemented("SetAttrBoolList has not been implemented yet."); } Status MlirAbstractOp::SetAttrShapeList(const char* attr_name, const int64_t** dims, const int* num_dims, int num_values) { - return tensorflow::errors::Unimplemented( - "SetAttrShapeList has not been implemented yet."); + return Unimplemented("SetAttrShapeList has not been implemented yet."); } Status MlirAbstractOp::SetAttrFunctionList( const char* attr_name, absl::Span values) { - return tensorflow::errors::Unimplemented( - "SetAttrFunctionList has not been implemented yet."); + return Unimplemented("SetAttrFunctionList has not been implemented yet."); } Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) { @@ -604,28 +554,101 @@ Status MlirFunctionContext::AddParameter(tensorflow::DataType dtype, } Status MlirAbstractOp::AddInput(AbstractTensorHandle* input) { + if (current_ods_input_ >= op_def_->input_arg_size()) + return InvalidArgument( + absl::StrCat("More Input() (", current_ods_input_, ") calls than the ", + op_def_->input_arg_size(), " allowed input_args ; for op ", + state_->name.getStringRef().str())); + auto* operand = dyn_cast(input); - if (!operand) { - return tensorflow::errors::InvalidArgument( - "Unable to cast input to MlirTensor"); - } + if (!operand) return InvalidArgument("Unable to cast input to MlirTensor"); operands_.push_back(operand->getValue()); + + // Get the next ArgDef and use it to infer the derived attributes associated + // to this input. + const tensorflow::OpDef::ArgDef& arg_def = + op_def_->input_arg(current_ods_input_++); + Type expected_type; + if (arg_def.type() != tensorflow::DT_INVALID) { + Builder builder(context_); + TF_RETURN_IF_ERROR( + tensorflow::ConvertDataType(arg_def.type(), builder, &expected_type)); + if (arg_def.is_ref()) { + Type output_type; + TF_RETURN_IF_ERROR(AddRef(expected_type, &output_type)); + expected_type = output_type; + } + } else { + expected_type = cast(input)->getElementType(); + } + if (!arg_def.type_attr().empty()) + attrs_[arg_def.type_attr()] = TypeAttr::get(expected_type); + return Status::OK(); } + +Status MlirAbstractOp::AddInputList( + absl::Span inputs) { + if (current_ods_input_ >= op_def_->input_arg_size()) + return InvalidArgument( + absl::StrCat("More Input() (", current_ods_input_, ") calls than the ", + op_def_->input_arg_size(), " allowed input_args")); + + for (AbstractTensorHandle* input : inputs) { + auto* operand = dyn_cast(input); + if (!operand) return InvalidArgument("Unable to cast input to MlirTensor"); + operands_.push_back(operand->getValue()); + } + + // Get the next ArgDef and use it to infer the derived attributes associated + // to this input. + const tensorflow::OpDef::ArgDef& arg_def = + op_def_->input_arg(current_ods_input_++); + if (!arg_def.number_attr().empty()) { + Builder builder(context_); + attrs_[arg_def.number_attr()] = builder.getI32IntegerAttr(inputs.size()); + // TODO(aminim): handle ref variable. + if (arg_def.type() != tensorflow::DT_INVALID) { + // TODO(aminim): check type wrt input + Type arg_def_type; + TF_RETURN_IF_ERROR( + ConvertDataType(arg_def.type(), builder, &arg_def_type)); + // Ensure each of the type in the list matches the op def type. 
+ // TODO(aminim): can we improve the error message with the actual types? + for (AbstractTensorHandle* input : inputs) + if (arg_def_type != cast(input)->getElementType()) + return InvalidArgument( + "Invalid input list: type mismatch the op def expectation"); + } else if (!inputs.empty()) { + if (arg_def.type_attr().empty()) + return FailedPrecondition( + "Invalid opdef type constraint: either type or type_attr required"); + + attrs_[arg_def.type_attr()] = + TypeAttr::get(cast(inputs.front())->getElementType()); + } + } else if (!arg_def.type_list_attr().empty()) { + // TODO(aminim): handle ref variable. + SmallVector types; + types.reserve(inputs.size()); + for (AbstractTensorHandle* input : inputs) + types.push_back(TypeAttr::get(cast(input)->getElementType())); + attrs_[arg_def.type_list_attr()] = ArrayAttr::get(types, GetContext()); + } + return Status::OK(); +} + Status MlirFunctionContext::Finalize(OutputList* outputs, AbstractFunction** f) { Block& body = func_.getBody().front(); SmallVector ret_operands; for (auto* output : outputs->outputs) { auto* operand = dyn_cast(output); - if (!operand) { - return tensorflow::errors::InvalidArgument( - "Capturing eager tensors is not supported yet."); - } - if (operand->getValue().getContext() != context_.get()) { - return tensorflow::errors::InvalidArgument( + if (!operand) + return InvalidArgument("Capturing eager tensors is not supported yet."); + if (operand->getValue().getContext() != context_.get()) + return InvalidArgument( "Capturing tensors from other context is not supported."); - } ret_operands.push_back(operand->getValue()); } builder_.create(func_.getLoc(), ret_operands); @@ -640,7 +663,6 @@ Status MlirFunctionContext::Finalize(OutputList* outputs, extern "C" { TracingContext* MlirTracingFactory(const char* fn_name, TF_Status* s) { - RegisterDialects(); return new MlirFunctionContext(fn_name); } } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc b/tensorflow/compiler/mlir/tensorflow/dialect_registration.h similarity index 50% rename from tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc rename to tensorflow/compiler/mlir/tensorflow/dialect_registration.h index 45985cea583..a63bfd154ab 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/dialect_registration.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,22 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_DIALECT_REGISTRATION_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_DIALECT_REGISTRATION_H_ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" namespace mlir { - -// Static initialization for TF dialect registration. 
-static DialectRegistration tf_ops; -static DialectRegistration - tf_executor_dialect; -static DialectRegistration - tf_device_dialect; -static DialectRegistration - tf_saved_model_dialect; -static DialectRegistration shape_dialect; - +// Inserts all the TensorFlow dialects in the provided registry. This is +// intended for tools that need to register dialects before parsing .mlir files. +inline void RegisterAllTensorFlowDialects(DialectRegistry ®istry) { + registry.insert(); +} } // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_DIALECT_REGISTRATION_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc index dfad1fce26d..40cc2c99c27 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc @@ -74,12 +74,9 @@ struct FuncAttrStorage : public AttributeStorage { // Get or create a shape attribute. ShapeAttr ShapeAttr::get(mlir::MLIRContext* context, llvm::Optional> shape) { - if (shape) - return Base::get(context, AttrKind::SHAPE, *shape, - /*unranked=*/false); + if (shape) return Base::get(context, *shape, /*unranked=*/false); - return Base::get(context, AttrKind::SHAPE, ArrayRef(), - /*unranked=*/true); + return Base::get(context, ArrayRef(), /*unranked=*/true); } llvm::Optional> ShapeAttr::getValue() const { @@ -112,12 +109,12 @@ bool ShapeAttr::hasStaticShape() const { FuncAttr FuncAttr::get(mlir::MLIRContext* context, llvm::StringRef name, DictionaryAttr attr) { auto symbol = SymbolRefAttr::get(name, context); - return Base::get(context, AttrKind::FUNC, symbol, attr); + return Base::get(context, symbol, attr); } FuncAttr FuncAttr::get(mlir::MLIRContext* context, SymbolRefAttr symbol, DictionaryAttr attr) { - return Base::get(context, AttrKind::FUNC, symbol, attr); + return Base::get(context, symbol, attr); } SymbolRefAttr FuncAttr::GetName() const { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h index 1edc7356ab4..5a18b77ab5c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h @@ -24,19 +24,6 @@ limitations under the License. namespace mlir { namespace TF { -namespace AttrKind { - -// List of supported custom TensorFlow Attribute kinds, necessary for -// isa/dyn_cast. -enum Kind { - FIRST_USED_TENSORFLOW_ATTR = Attribute::FIRST_TENSORFLOW_ATTR, - SHAPE = FIRST_USED_TENSORFLOW_ATTR, - FUNC, - LAST_USED_TENSORFLOW_ATTR, -}; - -} // namespace AttrKind - namespace detail { struct ShapeAttrStorage; @@ -70,8 +57,6 @@ class ShapeAttr : public Attribute::AttrBase= 0), it has static // shape. bool hasStaticShape() const; - - static bool kindof(unsigned kind) { return kind == AttrKind::SHAPE; } }; // Custom attribute to model AttrValue.value.func (NameAttrList type attribute). 
@@ -97,8 +82,6 @@ class FuncAttr SymbolRefAttr GetName() const; DictionaryAttr GetAttrs() const; - - static bool kindof(unsigned kind) { return kind == AttrKind::FUNC; } }; } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index 77008b55672..0e85582337d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -101,7 +101,8 @@ bool BlockWrapsSingleOp(Block* block) { } // end anonymous namespace TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context) - : Dialect(/*name=*/"tf_device", context) { + : Dialect(/*name=*/"tf_device", context, + TypeID::get()) { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc" @@ -118,31 +119,6 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context) // operation results are perfectly forwarded to the launch return. bool LaunchOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); } -//===----------------------------------------------------------------------===// -// tf_device.return -//===----------------------------------------------------------------------===// - -namespace { -ParseResult ParseReturnOp(OpAsmParser* parser, OperationState* state) { - llvm::SmallVector op_info; - llvm::SmallVector types; - llvm::SMLoc loc = parser->getCurrentLocation(); - return failure(parser->parseOperandList(op_info) || - (!op_info.empty() && parser->parseColonTypeList(types)) || - parser->resolveOperands(op_info, types, loc, state->operands)); -} - -void Print(ReturnOp op, OpAsmPrinter* p) { - *p << op.getOperationName(); - if (op.getNumOperands() > 0) { - *p << ' '; - p->printOperands(op.getOperands()); - *p << " : "; - interleaveComma(op.getOperandTypes(), *p); - } -} -} // anonymous namespace - //===----------------------------------------------------------------------===// // tf_device.parallel_execute //===----------------------------------------------------------------------===// @@ -393,7 +369,7 @@ void Print(ReplicateOp op, OpAsmPrinter* p) { // [%a, ...] as %block_arg0: type // packed_input // %b as %block_arg1: type - const int32_t n = op.n().getSExtValue(); + const int32_t n = op.n(); const int32_t num_replicated_inputs = (*op.operand_segment_sizes().int_value_begin()).getSExtValue(); const int32_t num_replicated_block_args = num_replicated_inputs / n; @@ -437,7 +413,7 @@ LogicalResult VerifyCompatibleTypes(Type a, Type b) { } LogicalResult Verify(ReplicateOp op) { - int32_t n = op.n().getSExtValue(); + int32_t n = op.n(); // Check number of devices, if set, matches `n`. if (op.devices().hasValue()) { @@ -694,12 +670,12 @@ void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList& results, results.insert(context); } +} // namespace tf_device +} // namespace mlir + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc" - -} // namespace tf_device -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h index d1ca07d85a7..5b1d9711875 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h @@ -36,15 +36,16 @@ namespace tf_device { // XlaRun. 
class TensorFlowDeviceDialect : public Dialect { public: + static StringRef getDialectNamespace() { return "tf_device"; } // Constructing TensorFlowDevice dialect under an non-null MLIRContext. explicit TensorFlowDeviceDialect(MLIRContext* context); }; +} // namespace tf_device +} // namespace mlir + // Declares the operations for this dialect using the generated header. #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h.inc" -} // namespace tf_device -} // namespace mlir - #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td index 565be63a74f..8f1cd6877e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td @@ -36,7 +36,7 @@ def TfDevice_Dialect : Dialect { XlaRun. }]; - let cppNamespace = "tf_device"; + let cppNamespace = "::mlir::tf_device"; } //===----------------------------------------------------------------------===// @@ -104,8 +104,7 @@ The `tf_device.return` operation terminates and returns values from a }]> ]; - let parser = [{ return Parse$cppClass(&parser, &result); }]; - let printer = [{ return Print(*this, &p); }]; + let assemblyFormat = "attr-dict ($results^ `:` type($results))?"; } def TfDevice_LaunchFuncOp : TfDevice_Op<"launch_func", []> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index c18723b0982..f2d0a548420 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -54,9 +54,6 @@ namespace tf_executor { namespace { -using TF::DropRefType; -using TF::DropTypeSubTypes; - struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; @@ -75,9 +72,8 @@ struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface { } }; -struct TensorFlowExecutorOpFolderDialectInterface - : public OpFolderDialectInterface { - using OpFolderDialectInterface::OpFolderDialectInterface; +struct TensorFlowExecutorDialectFoldInterface : public DialectFoldInterface { + using DialectFoldInterface::DialectFoldInterface; // Registered hook to check if the given region, which is attached to an // operation that is *not* isolated from above (i.e. 
no internal regions @@ -92,14 +88,15 @@ struct TensorFlowExecutorOpFolderDialectInterface } // namespace TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context) - : Dialect(/*name=*/"tf_executor", context) { + : Dialect(/*name=*/"tf_executor", context, + TypeID::get()) { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc" >(); addInterfaces(); + TensorFlowExecutorDialectFoldInterface>(); addTypes(); } @@ -253,33 +250,6 @@ ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) { // tf_executor.fetch //===----------------------------------------------------------------------===// -namespace { - -void Print(FetchOp fetch, OpAsmPrinter &p) { - p << fetch.getOperationName(); - if (fetch.getNumOperands() > 0) { - p << ' '; - p.printOperands(fetch.operand_begin(), fetch.operand_end()); - p << " : "; - interleaveComma(fetch.getOperandTypes(), p); - } - p.printOptionalAttrDict(fetch.getAttrs()); -} - -ParseResult ParseFetchOp(OpAsmParser &parser, OperationState &result) { - SmallVector opInfo; - SmallVector types; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure(parser.parseOperandList(opInfo) || - (!opInfo.empty() && parser.parseColonTypeList(types)) || - parser.resolveOperands(opInfo, types, loc, result.operands) || - parser.parseOptionalAttrDict(result.attributes) - - ); -} - -} // anonymous namespace - //===----------------------------------------------------------------------===// // tf_executor.island //===----------------------------------------------------------------------===// @@ -414,31 +384,6 @@ ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) { // tf_executor.yield //===----------------------------------------------------------------------===// -namespace { - -void Print(YieldOp yield, OpAsmPrinter &p) { - p << yield.getOperationName(); - if (yield.getNumOperands() > 0) { - p << ' '; - p.printOperands(yield.operand_begin(), yield.operand_end()); - p << " : "; - interleaveComma(yield.getOperandTypes(), p); - } - p.printOptionalAttrDict(yield.getAttrs()); -} - -ParseResult ParseYieldOp(OpAsmParser &parser, OperationState &result) { - SmallVector op_info; - SmallVector types; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure(parser.parseOperandList(op_info) || - (!op_info.empty() && parser.parseColonTypeList(types)) || - parser.resolveOperands(op_info, types, loc, result.operands) || - parser.parseOptionalAttrDict(result.attributes)); -} - -} // anonymous namespace - //===----------------------------------------------------------------------===// // tf_executor.Switch //===----------------------------------------------------------------------===// @@ -550,8 +495,8 @@ LogicalResult Verify(SwitchNOp switchn) { << operand0_tensor_type << " vs " << output_tensor_type; } Type broadcasted_type = OpTrait::util::getBroadcastedType( - DropRefType(DropTypeSubTypes(operand0_tensor_type)), - DropRefType(DropTypeSubTypes(output_tensor_type))); + TF::DropRefAndSubTypes(operand0_tensor_type), + TF::DropRefAndSubTypes(output_tensor_type)); if (!broadcasted_type) { return switchn.emitOpError() << "expects data operand to be broadcastable with all output types" @@ -667,8 +612,8 @@ LogicalResult Verify(MergeOp merge) { << operand_tensor_ty << " vs " << output_tensor_ty; } Type broadcasted_type = OpTrait::util::getBroadcastedType( - DropRefType(DropTypeSubTypes(output_tensor_ty)), - DropRefType(DropTypeSubTypes(operand_tensor_ty))); + 
TF::DropRefAndSubTypes(output_tensor_ty), + TF::DropRefAndSubTypes(operand_tensor_ty)); if (!broadcasted_type) return merge.emitOpError() << "expects all operands to be broadcastable with output type" @@ -851,23 +796,6 @@ LogicalResult Verify(NextIterationSourceOp source) { return success(); } -void Print(NextIterationSourceOp next_iteration, OpAsmPrinter &p) { - p << next_iteration.getOperationName() << " : " << next_iteration.getType(0); - p.printOptionalAttrDict(next_iteration.getAttrs()); -} - -ParseResult ParseNextIterationSourceOp(OpAsmParser &parser, - OperationState &result) { - SmallVector types; - if (parser.parseColonTypeList(types)) return failure(); - - MLIRContext *context = parser.getBuilder().getContext(); - Type token_type = TokenType::get(context); - Type control_type = ControlType::get(context); - result.addTypes({types.front(), token_type, control_type}); - return parser.parseOptionalAttrDict(result.attributes); -} - } // anonymous namespace //===----------------------------------------------------------------------===// @@ -894,36 +822,6 @@ LogicalResult Verify(NextIterationSinkOp sink) { return success(); } -void Print(NextIterationSinkOp next_iteration, OpAsmPrinter &p) { - p << next_iteration.getOperationName() << " ["; - p.printOperand(next_iteration.getOperand(0)); - p << "] "; - p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1)); - p << " : " << next_iteration.getOperand(1).getType(); - p.printOptionalAttrDict(next_iteration.getAttrs()); -} - -ParseResult ParseNextIterationSinkOp(OpAsmParser &parser, - OperationState &result) { - SmallVector op_infos; - llvm::SMLoc loc = parser.getCurrentLocation(); - - // First type is always the token consumed from the NextIteration.source - Type token_type = TokenType::get(parser.getBuilder().getContext()); - SmallVector types = {token_type}; - - if (parser.parseOperandList(op_infos, 1, OpAsmParser::Delimiter::Square) || - parser.parseOperandList(op_infos) || parser.parseColonTypeList(types)) - return failure(); - - Type control_type = ControlType::get(parser.getBuilder().getContext()); - types.append(op_infos.size() - 2, control_type); - if (parser.resolveOperands(op_infos, types, loc, result.operands)) - return failure(); - - return parser.parseOptionalAttrDict(result.attributes); -} - } // anonymous namespace //===----------------------------------------------------------------------===// @@ -962,32 +860,6 @@ ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) { // tf_executor.ControlTrigger //===----------------------------------------------------------------------===// -namespace { - -void Print(ControlTriggerOp trigger, OpAsmPrinter &p) { - p << trigger.getOperationName() << ' '; - p.printOperands(trigger.getOperands()); - p.printOptionalAttrDict(trigger.getAttrs()); -} - -ParseResult ParseControlTriggerOp(OpAsmParser &parser, OperationState &result) { - SmallVector op_infos; - SmallVector types; - llvm::SMLoc loc = parser.getCurrentLocation(); - - if (parser.parseOperandList(op_infos)) return failure(); - Type control_type = ControlType::get(parser.getBuilder().getContext()); - types.append(op_infos.size(), control_type); - if (parser.resolveOperands(op_infos, types, loc, result.operands)) - return failure(); - - // Single control as the only output - result.types.push_back(control_type); - return parser.parseOptionalAttrDict(result.attributes); -} - -} // anonymous namespace - //===----------------------------------------------------------------------===// // tf_executor.LoopCond 
//===----------------------------------------------------------------------===// @@ -1249,12 +1121,12 @@ LogicalResult IslandOp::fold(llvm::ArrayRef operands, return success(); } +} // namespace tf_executor +} // namespace mlir + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc" - -} // namespace tf_executor -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h index 3bb30f16c3d..2bc13556b4b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h @@ -35,6 +35,7 @@ namespace tf_executor { class TensorFlowExecutorDialect : public Dialect { public: + static StringRef getDialectNamespace() { return "tf_executor"; } explicit TensorFlowExecutorDialect(MLIRContext *context); // Parses a type registered to this dialect. @@ -44,44 +45,23 @@ class TensorFlowExecutorDialect : public Dialect { void printType(Type type, DialectAsmPrinter &os) const override; }; -namespace TFTypes { -enum Kind { - Control = Type::FIRST_TENSORFLOW_EXECUTOR_TYPE, - Token, -}; -} // namespace TFTypes - // The Control type is a token-like value that models control dependencies from // TensorFlow graphs. class ControlType : public Type::TypeBase { public: using Base::Base; - - static ControlType get(MLIRContext *context) { - return Base::get(context, TFTypes::Control); - } - - // Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { return kind == TFTypes::Control; } }; class TokenType : public Type::TypeBase { public: using Base::Base; - - static TokenType get(MLIRContext *context) { - return Base::get(context, TFTypes::Token); - } - - // Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { return kind == TFTypes::Token; } }; +} // namespace tf_executor +} // namespace mlir + // Declares the operations for this dialect using the generated header. #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h.inc" -} // namespace tf_executor -} // namespace mlir - #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td index 3081018b8da..713ddc44cba 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td @@ -43,14 +43,16 @@ def TfExecutor_Dialect : Dialect { value). }]; - let cppNamespace = "tf_executor"; + let cppNamespace = "::mlir::tf_executor"; } // Control type. -def TfeControlType : Type()">, "control">; +def TfeControlType : Type()">, "control">, + BuildableType<"$_builder.getType()">; // Token type. -def TfeTokenType : Type()">, "token">; +def TfeTokenType : Type()">, "token">, + BuildableType<"$_builder.getType()">; // TODO(hinsu): Define and use TensorType instead of AnyType for data operands // and results. For example, MergeOp output type. @@ -148,7 +150,11 @@ def TfExecutor_FetchOp : TfExecutor_Op<"fetch", }]> ]; + let assemblyFormat = "($fetches^ `:` type($fetches))? 
attr-dict"; + let verifier = ?; + let printer = ?; + let parser = ?; } def TfExecutor_IslandOp : TfExecutor_Op<"island", @@ -229,7 +235,11 @@ def TfExecutor_YieldOp : TfExecutor_Op<"yield", }]> ]; + let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict"; + let verifier = ?; + let printer = ?; + let parser = ?; } def TfExecutor_SwitchOp : TfExecutor_Op<"Switch", @@ -466,6 +476,10 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", } }]; + let assemblyFormat = "`:` type($output) attr-dict"; + + let printer = ?; + let parser = ?; } @@ -527,6 +541,11 @@ def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink", result.attributes.append(attributes.begin(), attributes.end()); }]> ]; + + let assemblyFormat = " `[` $token `]` $input (`,` $controlInputs^)? `:` type($input) attr-dict"; + + let printer = ?; + let parser = ?; } def TfExecutor_ExitOp : TfExecutor_Op<"Exit", @@ -552,7 +571,7 @@ def TfExecutor_ExitOp : TfExecutor_Op<"Exit", .Attr("T: type") For example: - %1:2 = tf_executor.Exit %0#0 {T: "tfdtype$DT_INT32"} : tensor<*xi32> + %1:2 = tf_executor.Exit %0#0 : tensor<*xi32> {T: "tfdtype$DT_INT32"} Note: Additional result corresponds to the control output. }]; @@ -607,6 +626,11 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger", result.attributes.append(attributes.begin(), attributes.end()); }]> ]; + + let assemblyFormat = "$controlInputs attr-dict"; + + let printer = ?; + let parser = ?; } def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 081903d13cf..ba9ba8ea248 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -52,6 +52,12 @@ an output element, this operation computes \\(y = |x|\\). def TF_AcosOp : TF_Op<"Acos", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes acos of x element-wise."; + let description = [{ +Provided an input tensor, the `tf.math.acos` operation returns the inverse cosine of each element of the tensor. If `y = tf.math.cos(x)` then, `x = tf.math.acos(y)`. + + Input range is `[-1, 1]` and the output has a range of `[0, pi]`. + }]; + let arguments = (ins TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x ); @@ -87,29 +93,6 @@ tf.math.acosh(x) ==> [nan nan 0. 0.62236255 5.9914584 9.903487 inf] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>, - WithBroadcastableBinOpBuilder { - let summary = "Returns x + y element-wise."; - - let description = [{ -*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - }]; - - let arguments = (ins - TF_NumberOrStrTensor:$x, - TF_NumberOrStrTensor:$y - ); - - let results = (outs - TF_NumberOrStrTensor:$z - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - - let hasCanonicalizer = 1; -} - def TF_AddNOp : TF_Op<"AddN", [Commutative, NoSideEffect]> { let summary = "Add all input tensors element wise."; @@ -136,31 +119,6 @@ Inputs must be of same size and shape. 
let hasFolder = 1; } -def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>, - WithBroadcastableBinOpBuilder { - let summary = "Returns x + y element-wise."; - - let description = [{ -*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - }]; - - let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$y - ); - - let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$z - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - - let hasCanonicalizer = 1; - - let hasFolder = 1; -} - def TF_AdjustContrastv2Op : TF_Op<"AdjustContrastv2", [NoSideEffect]> { let summary = "Adjust the contrast of one or more images."; @@ -571,7 +529,7 @@ see the incremented value or a subsequent newer one. }]; let arguments = (ins - TF_ResourceTensor:$resource, + Arg:$resource, TF_Tensor:$value ); @@ -589,7 +547,7 @@ see the decremented value or a subsequent newer one. }]; let arguments = (ins - TF_ResourceTensor:$resource, + Arg:$resource, TF_Tensor:$value ); @@ -607,7 +565,7 @@ this value or a subsequent newer value of the variable. }]; let arguments = (ins - TF_ResourceTensor:$resource, + Arg:$resource, TF_Tensor:$value ); @@ -859,15 +817,15 @@ about broadcasting }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x, - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$x, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$y, DefaultValuedAttr:$adj_x, DefaultValuedAttr:$adj_y ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -965,7 +923,41 @@ reverse of SpaceToBatch. See below for a precise description. TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>; } -def TF_BiasAddOp : TF_Op<"BiasAdd", [NoSideEffect]> { +def TF_BetaincOp : TF_Op<"Betainc", [NoSideEffect]> { + let summary = [{ +Compute the regularized incomplete beta integral \\(I_x(a, b)\\). + }]; + + let description = [{ +The regularized incomplete beta integral is defined as: + + +\\(I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}\\) + +where + + +\\(B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt\\) + + +is the incomplete beta function and \\(B(a, b)\\) is the *complete* +beta function. + }]; + + let arguments = (ins + TF_F32OrF64Tensor:$a, + TF_F32OrF64Tensor:$b, + TF_F32OrF64Tensor:$x + ); + + let results = (outs + TF_F32OrF64Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_BiasAddOp : TF_Op<"BiasAdd", [NoSideEffect, TF_ContractionFusableInterface]> { let summary = "Adds `bias` to `value`."; let description = [{ @@ -986,6 +978,11 @@ Broadcasting is supported, so `value` may have any number of dimensions. 
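As a quick illustration of the broadcasting rule stated in the BiasAdd description above, here is a hedged Python sketch (TensorFlow 2.x assumed; the shapes are arbitrary examples):

```python
import tensorflow as tf

# `bias` is 1-D and matches the last (channel) dimension of `value`;
# it is broadcast across all leading dimensions.
value = tf.zeros([2, 3, 4])               # e.g. batch of 2, spatial 3, channels 4
bias = tf.constant([1.0, 2.0, 3.0, 4.0])
out = tf.nn.bias_add(value, bias)         # shape [2, 3, 4]
print(out[0, 0].numpy())                  # [1. 2. 3. 4.]
```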
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + let extraClassDeclaration = [{ + // TF_ContractionFusableInterface: + Optional GetContractionFusion(); + }]; + let verifier = [{ return Verify(*this); }]; @@ -1319,6 +1316,7 @@ subsequent operation and then be optimized away, however.) let verifier = [{ return Verify(*this); }]; + let hasFolder = 1; } def TF_BucketizeOp : TF_Op<"Bucketize", [NoSideEffect, SameOperandsAndResultShape]> { @@ -1350,48 +1348,6 @@ then the output will be TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_CaseOp : TF_Op<"Case", []> { - let summary = [{ -An n-way switch statement which calls a single branch function. - }]; - - let description = [{ -An n-way switch statement, implementing the following: - ``` - switch (branch_index) { - case 0: - output = branches[0](input); - break; - case 1: - output = branches[1](input); - break; - ... - case [[nbranches-1]]: - default: - output = branches[nbranches-1](input); - break; - } - ``` - }]; - - let arguments = (ins - I32Tensor:$branch_index, - Variadic:$input, - - Confined]>:$branches, - DefaultValuedAttr:$output_shapes - ); - - let results = (outs - Variadic:$output - ); - - TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; - TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; - - let hasCanonicalizer = 1; -} - def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "Cast x of type SrcT to y of DstT."; @@ -1446,6 +1402,38 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_CholeskyOp : TF_Op<"Cholesky", [NoSideEffect]> { + let summary = [{ +Computes the Cholesky decomposition of one or more square matrices. + }]; + + let description = [{ +The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +form square matrices. + +The input has to be symmetric and positive definite. Only the lower-triangular +part of the input will be used for this operation. The upper-triangular part +will not be read. + +The output is a tensor of the same shape as the input +containing the Cholesky decompositions for all input submatrices `[..., :, :]`. + +**Note**: The gradient computation on GPU is faster for large matrices but +not for large batch dimensions when the submatrices are small. In this +case it might be faster to use the CPU. + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { let summary = "Clips tensor values to a specified min and max."; @@ -1715,6 +1703,24 @@ def TF_ConcatV2Op : TF_Op<"ConcatV2", [NoSideEffect]> { let hasCanonicalizer = 1; } +def TF_ConfigureDistributedTPUOp : TF_Op<"ConfigureDistributedTPU", []> { + let summary = [{ +Sets up the centralized structures for a distributed TPU system. 
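The Cholesky op registered above requires a symmetric positive-definite input and reads only its lower-triangular part. A small Python sketch of that contract, assuming TensorFlow 2.x (the matrix is an arbitrary example):

```python
import tensorflow as tf

# Symmetric positive-definite input; only the lower triangle is used.
a = tf.constant([[4.0, 2.0],
                 [2.0, 3.0]])
l = tf.linalg.cholesky(a)                         # lower-triangular L with A = L @ L^T
print(l.numpy())                                  # [[2. 0.], [1. ~1.4142]]

# Sanity check: reconstruct the input from the factor.
print(tf.matmul(l, l, transpose_b=True).numpy())  # ~[[4. 2.], [2. 3.]]
```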
+ }]; + + let arguments = (ins + StrAttr:$embedding_config, + StrAttr:$tpu_embedding_config, + DefaultValuedAttr:$is_global_init, + DefaultValuedAttr:$enable_whole_mesh_compilations, + DefaultValuedAttr:$compilation_failure_closes_chips + ); + + let results = (outs + TF_StrTensor:$topology + ); +} + def TF_ConjOp : TF_Op<"Conj", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns the complex conjugate of a complex number."; @@ -2067,17 +2073,73 @@ and `B, D, F, H` as group 1. Thus we get the outputs: }]; let arguments = (ins - TensorOf<[BF16, F32, I32, TF_Uint32]>:$input, + TensorOf<[BF16, F16, F32, I32, TF_Uint32]>:$input, I32Tensor:$group_assignment ); let results = (outs - TensorOf<[BF16, F32, I32, TF_Uint32]>:$output + TensorOf<[BF16, F16, F32, I32, TF_Uint32]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_CumprodOp : TF_Op<"Cumprod", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> { + let summary = [{ +Compute the cumulative product of the tensor `x` along `axis`. + }]; + + let description = [{ +By default, this op performs an inclusive cumprod, which means that the first +element of the input is identical to the first element of the output: + +```python +tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] +``` + +By setting the `exclusive` kwarg to `True`, an exclusive cumprod is +performed instead: + +```python +tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] +``` + +By setting the `reverse` kwarg to `True`, the cumprod is performed in the +opposite direction: + +```python +tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] +``` + +This is more efficient than using separate `tf.reverse` ops. + +The `reverse` and `exclusive` kwargs can also be combined: + +```python +tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] +``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, + TF_I32OrI64Tensor:$axis, + + DefaultValuedAttr:$exclusive, + DefaultValuedAttr:$reverse + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; + + let verifier = [{ + return Verify(*this); + }]; +} + def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> { let summary = "Compute the cumulative sum of the tensor `x` along `axis`."; @@ -2126,6 +2188,10 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; + + let verifier = [{ + return Verify(*this); + }]; } def TF_DataFormatDimMapOp : TF_Op<"DataFormatDimMap", [NoSideEffect, SameOperandsAndResultType]> { @@ -2151,6 +2217,82 @@ the source data format. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_DataFormatVecPermuteOp : TF_Op<"DataFormatVecPermute", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Permute input tensor from `src_format` to `dst_format`."; + + let description = [{ +Input tensor must be a vector of size 4, or a 4x2 tensor. 
+ +For example, with `src_format` of `NHWC`, `dst_format` of `NCHW`, and inputs: +``` +[1, 2, 3, 4] +``` +and +``` +[[1, 2, 3, 4], + [5, 6, 7, 8]] +``` +, the outputs will be (respectively): +``` +[1, 4, 2, 3] +``` +and +``` +[[1, 4, 2, 3], + [5, 8, 6, 7]] +``` + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$x, + + DefaultValuedAttr:$src_format, + DefaultValuedAttr:$dst_format + ); + + let results = (outs + TF_I32OrI64Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let verifier = [{ return Verify(*this); }]; +} + +def TF_DebugIdentityV2Op : TF_Op<"DebugIdentityV2", []> { + let summary = "Debug Identity V2 Op."; + + let description = [{ +Provides an identity mapping from input to output, while writing the content of +the input tensor by calling DebugEventsWriter. + +The semantics of the input tensor depends on tensor_debug_mode. In typical +usage, the input tensor comes directly from the user computation only when +graph_debug_mode is FULL_TENSOR (see protobuf/debug_event.proto for a +list of all the possible values of graph_debug_mode). For the other debug modes, +the input tensor should be produced by an additional op or subgraph that +computes summary information about one or more tensors. + }]; + + let arguments = (ins + TF_Tensor:$input, + + StrAttr:$tfdbg_context_id, + StrAttr:$op_name, + DefaultValuedAttr:$output_slot, + DefaultValuedAttr:$tensor_debug_mode, + DefaultValuedAttr:$debug_urls, + DefaultValuedAttr:$circular_buffer_size, + StrAttr:$tfdbg_run_id + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_DecodeAndCropJpegOp : TF_Op<"DecodeAndCropJpeg", [NoSideEffect]> { let summary = "Decode and Crop a JPEG-encoded image to a uint8 tensor."; @@ -2444,6 +2586,54 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_DepthwiseConv2dNativeBackpropFilterOp : TF_Op<"DepthwiseConv2dNativeBackpropFilter", [NoSideEffect]> { + let summary = [{ +Computes the gradients of depthwise convolution with respect to the filter. + }]; + + let arguments = (ins + TF_FpTensor:$input, + I32Tensor:$filter_sizes, + TF_FpTensor:$out_backprop, + + I64ArrayAttr:$strides, + TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, + DefaultValuedAttr:$explicit_paddings, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$dilations + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_DepthwiseConv2dNativeBackpropInputOp : TF_Op<"DepthwiseConv2dNativeBackpropInput", [NoSideEffect]> { + let summary = [{ +Computes the gradients of depthwise convolution with respect to the input. + }]; + + let arguments = (ins + I32Tensor:$input_sizes, + TF_FpTensor:$filter, + TF_FpTensor:$out_backprop, + + I64ArrayAttr:$strides, + TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, + DefaultValuedAttr:$explicit_paddings, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$dilations + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; +} + def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> { let summary = "Return the index of device the op runs."; @@ -2463,6 +2653,40 @@ this op runs. 
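The example in the `DataFormatVecPermute` description above can be reproduced by reading the permutation straight off the two format strings. A small Python sketch of that mapping for the size-4 vector case (`permute_vec` is an illustrative name; both formats are assumed to be permutations of the same four characters):

```python
def permute_vec(x, src_format="NHWC", dst_format="NCHW"):
    # Entry i of the output carries the value that dimension dst_format[i]
    # held in the src_format ordering.
    return [x[src_format.index(c)] for c in dst_format]

print(permute_vec([1, 2, 3, 4]))  # [1, 4, 2, 3], as in the example above
```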
The length of the list is returned in two cases: ); } +def TF_DiagOp : TF_Op<"Diag", [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "Returns a diagonal tensor with a given diagonal values."; + + let description = [{ +Given a `diagonal`, this operation returns a tensor with the `diagonal` and +everything else padded with zeros. The diagonal is computed as follows: + +Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of +rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where: + +`output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else. + +For example: + +``` +# 'diagonal' is [1, 2, 3, 4] +tf.diag(diagonal) ==> [[1, 0, 0, 0] + [0, 2, 0, 0] + [0, 0, 3, 0] + [0, 0, 0, 4]] +``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$diagonal + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_DiagPartOp : TF_Op<"DiagPart", [NoSideEffect]> { let summary = "Returns the diagonal part of the tensor."; @@ -2543,27 +2767,6 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape, TF_SameOpe let hasFolder = 1; } -def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, - WithBroadcastableBinOpBuilder { - let summary = "Returns 0 if the denominator is zero."; - - let description = [{ -*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - }]; - - let arguments = (ins - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y - ); - - let results = (outs - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; -} - def TF_DynamicStitchOp : TF_Op<"DynamicStitch", [NoSideEffect, SameVariadicOperandSize]> { let summary = [{ Interleave the values from the `data` tensors into a single tensor. @@ -3117,6 +3320,27 @@ i.e. `exp(x) - 1` or `e^(x) - 1`, where `x` is the input tensor. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_ExtractImagePatchesOp : TF_Op<"ExtractImagePatches", [NoSideEffect]> { + let summary = [{ +Extract `patches` from `images` and put them in the "depth" output dimension. + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$images, + + Confined]>:$ksizes, + Confined]>:$strides, + Confined]>:$rates, + TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$patches + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_FFTOp : TF_Op<"FFT", [NoSideEffect]> { let summary = "Fast Fourier transform."; @@ -3589,6 +3813,95 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. }]; } +def TF_FusedBatchNormV2Op : TF_Op<"FusedBatchNormV2", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> { + let summary = "Batch normalization."; + + let description = [{ +Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +The size of 1D Tensors matches the dimension C of the 4D Tensors. 
+ }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32]>:$x, + F32Tensor:$scale, + F32Tensor:$offset, + F32Tensor:$mean, + F32Tensor:$variance, + + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$exponential_avg_factor, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$is_training + ); + + let results = (outs + TensorOf<[BF16, F16, F32]>:$y, + F32Tensor:$batch_mean, + F32Tensor:$batch_variance, + F32Tensor:$reserve_space_1, + F32Tensor:$reserve_space_2 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; + + let extraClassDeclaration = [{ + // TF_FoldOperandsTransposeInterface: + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {0}; } + LogicalResult FoldOperandsPermutation(ArrayRef permutation); + + // TF_LayoutSensitiveInterface: + StringRef GetOptimalLayout(const RuntimeDevices& devices); + LogicalResult UpdateDataFormat(StringRef data_format); + }]; +} + +def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> { + let summary = "Batch normalization."; + + let description = [{ +Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +The size of 1D Tensors matches the dimension C of the 4D Tensors. + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32]>:$x, + F32Tensor:$scale, + F32Tensor:$offset, + F32Tensor:$mean, + F32Tensor:$variance, + + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$exponential_avg_factor, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$is_training + ); + + let results = (outs + TensorOf<[BF16, F16, F32]>:$y, + F32Tensor:$batch_mean, + F32Tensor:$batch_variance, + F32Tensor:$reserve_space_1, + F32Tensor:$reserve_space_2, + F32Tensor:$reserve_space_3 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; + + let extraClassDeclaration = [{ + // TF_FoldOperandsTransposeInterface: + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {0}; } + LogicalResult FoldOperandsPermutation(ArrayRef permutation); + + // TF_LayoutSensitiveInterface: + StringRef GetOptimalLayout(const RuntimeDevices& devices); + LogicalResult UpdateDataFormat(StringRef data_format); + }]; +} + def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> { let summary = "Gather slices from `params` according to `indices`."; @@ -3922,7 +4235,7 @@ table will be immutable. ); let results = (outs - TF_ResourceTensor:$table_handle + Res:$table_handle ); } @@ -4227,33 +4540,37 @@ tf.imag(input) ==> [4.75, 5.75] TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; } -def TF_InitializeTableFromTextFileV2Op : TF_Op<"InitializeTableFromTextFileV2", []> { - let summary = "Initializes a table from a text file."; - - let description = [{ -It inserts one key-value pair into the table for each line of the file. -The key and value is extracted from the whole line content, elements from the -split line based on `delimiter` or the line number (starting from zero). -Where to extract the key and value from a line is specified by `key_index` and -`value_index`. - -- A value of -1 means use the line number(starting from zero), expects `int64`. -- A value of -2 means use the whole line content, expects `string`. -- A value >= 0 means use the index (starting at zero) of the split line based - on `delimiter`. 
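As a rough illustration of the `FusedBatchNormV2`/`FusedBatchNormV3` arithmetic in the inference case (`is_training` false), the per-channel normalization can be sketched in NumPy. The function name, the NHWC layout, and the 1e-4 epsilon are assumptions made for the sketch; the batch statistics and reserve-space outputs are omitted:

```python
import numpy as np

def fused_batch_norm_inference(x, scale, offset, mean, variance, epsilon=1e-4):
    # x: [N, H, W, C] in NHWC; scale/offset/mean/variance: [C], matching the
    # 1-D operands of the op. Only the normalized output y is sketched.
    inv = scale / np.sqrt(variance + epsilon)
    return x * inv + (offset - mean * inv)

x = np.random.rand(2, 4, 4, 3).astype(np.float32)
y = fused_batch_norm_inference(x,
                               scale=np.ones(3, np.float32),
                               offset=np.zeros(3, np.float32),
                               mean=x.mean(axis=(0, 1, 2)),
                               variance=x.var(axis=(0, 1, 2)))
```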
+def TF_InfeedDequeueOp : TF_Op<"InfeedDequeue", []> { + let summary = [{ +A placeholder op for a value that will be fed into the computation. }]; let arguments = (ins - TF_ResourceTensor:$table_handle, - TF_StrTensor:$filename, + TF_ShapeAttr:$shape + ); - Confined]>:$key_index, - Confined]>:$value_index, - Confined, [IntMinValue<-1>]>:$vocab_size, - DefaultValuedAttr:$delimiter + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_InitializeTableV2Op : TF_Op<"InitializeTableV2", []> { + let summary = [{ +Table initializer that takes two tensors for keys and values respectively. + }]; + + let arguments = (ins + Arg:$table_handle, + TF_Tensor:$keys, + TF_Tensor:$values ); let results = (outs); + + TF_DerivedOperandTypeAttr Tval = TF_DerivedOperandTypeAttr<2>; + TF_DerivedOperandTypeAttr Tkey = TF_DerivedOperandTypeAttr<1>; } def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> { @@ -4563,7 +4880,7 @@ def TF_LRNGradOp : TF_Op<"LRNGrad", [NoSideEffect]> { TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType]> { +def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface]> { let summary = "Computes rectified linear: `max(features, features * alpha)`."; let arguments = (ins @@ -4579,6 +4896,11 @@ def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let hasFolder = 1; + + let extraClassDeclaration = [{ + // TF_ContractionFusableInterface: + Optional GetContractionFusion(); + }]; } def TF_LeakyReluGradOp : TF_Op<"LeakyReluGrad", [NoSideEffect, SameOperandsAndResultType]> { @@ -4772,6 +5094,49 @@ tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<2>; } +def TF_ListDiffOp : TF_Op<"ListDiff", [NoSideEffect]> { + let summary = [{ +Computes the difference between two lists of numbers or strings. + }]; + + let description = [{ +Given a list `x` and a list `y`, this operation returns a list `out` that +represents all values that are in `x` but not in `y`. The returned list `out` +is sorted in the same order that the numbers appear in `x` (duplicates are +preserved). This operation also returns a list `idx` that represents the +position of each `out` element in `x`. 
In other words: + +`out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]` + +For example, given this input: + +``` +x = [1, 2, 3, 4, 5, 6] +y = [1, 3, 5] +``` + +This operation would return: + +``` +out ==> [2, 4, 6] +idx ==> [1, 3, 5] +``` + }]; + + let arguments = (ins + TF_Tensor:$x, + TF_Tensor:$y + ); + + let results = (outs + TF_Tensor:$out, + TF_I32OrI64Tensor:$idx + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr out_idx = TF_DerivedResultTypeAttr<1>; +} + def TF_LogOp : TF_Op<"Log", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes natural logarithm of x element-wise."; @@ -4896,6 +5261,22 @@ def TF_LogicalOrOp : TF_Op<"LogicalOr", [Commutative, NoSideEffect, ResultsBroad ); } +def TF_LookupTableExportV2Op : TF_Op<"LookupTableExportV2", []> { + let summary = "Outputs all keys and values in the table."; + + let arguments = (ins + Arg:$table_handle + ); + + let results = (outs + TF_Tensor:$keys, + TF_Tensor:$values + ); + + TF_DerivedResultTypeAttr Tkeys = TF_DerivedResultTypeAttr<0>; + TF_DerivedResultTypeAttr Tvalues = TF_DerivedResultTypeAttr<1>; +} + def TF_LookupTableFindV2Op : TF_Op<"LookupTableFindV2", []> { let summary = "Looks up keys in a table, outputs the corresponding values."; @@ -4908,7 +5289,7 @@ table. It must also be of the same type as the table values. }]; let arguments = (ins - TF_ResourceTensor:$table_handle, + Arg:$table_handle, TF_Tensor:$keys, TF_Tensor:$default_value ); @@ -4932,7 +5313,7 @@ The tensor `values` must be of the type of the table values. }]; let arguments = (ins - TF_ResourceTensor:$table_handle, + Arg:$table_handle, TF_Tensor:$keys, TF_Tensor:$values ); @@ -4943,6 +5324,44 @@ The tensor `values` must be of the type of the table values. TF_DerivedOperandTypeAttr Tout = TF_DerivedOperandTypeAttr<2>; } +def TF_LookupTableInsertV2Op : TF_Op<"LookupTableInsertV2", []> { + let summary = "Updates the table to associates keys with values."; + + let description = [{ +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + }]; + + let arguments = (ins + Arg:$table_handle, + TF_Tensor:$keys, + TF_Tensor:$values + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tin = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr Tout = TF_DerivedOperandTypeAttr<2>; +} + +def TF_LookupTableRemoveV2Op : TF_Op<"LookupTableRemoveV2", []> { + let summary = "Removes keys and its associated values from a table."; + + let description = [{ +The tensor `keys` must of the same type as the keys of the table. Keys not +already in the table are silently ignored. + }]; + + let arguments = (ins + Arg:$table_handle, + TF_Tensor:$keys + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tin = TF_DerivedOperandTypeAttr<1>; +} + def TF_LookupTableSizeV2Op : TF_Op<"LookupTableSizeV2", []> { let summary = "Computes the number of elements in the given table."; @@ -4955,6 +5374,44 @@ def TF_LookupTableSizeV2Op : TF_Op<"LookupTableSizeV2", []> { ); } +def TF_LowerBoundOp : TF_Op<"LowerBound", [NoSideEffect]> { + let summary = [{ +Applies lower_bound(sorted_search_values, values) along each row. + }]; + + let description = [{ +Each set of rows with the same index in (sorted_inputs, values) is treated +independently. The resulting row is the equivalent of calling +`np.searchsorted(sorted_inputs, values, side='left')`. + +The result is not a global index to the entire +`Tensor`, but rather just the index in the last dimension. 
+ +A 2-D example: + sorted_sequence = [[0, 3, 9, 9, 10], + [1, 2, 3, 4, 5]] + values = [[2, 4, 9], + [0, 2, 6]] + + result = LowerBound(sorted_sequence, values) + + result == [[1, 2, 2], + [0, 1, 5]] + }]; + + let arguments = (ins + TF_Tensor:$sorted_inputs, + TF_Tensor:$values + ); + + let results = (outs + TF_I32OrI64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>; +} + def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { let summary = [{ Multiply the matrix "a" by the matrix "b". @@ -5464,6 +5921,36 @@ tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_MatrixInverseOp : TF_Op<"MatrixInverse", [NoSideEffect]> { + let summary = [{ +Computes the inverse of one or more square invertible matrices or their adjoints (conjugate transposes). + }]; + + let description = [{ +The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +form square matrices. The output is a tensor of the same shape as the input +containing the inverse for all input submatrices `[..., :, :]`. + +The op uses LU decomposition with partial pivoting to compute the inverses. + +If a matrix is not invertible there is no guarantee what the op does. It +may detect the condition and raise an exception or it may simply return a +garbage result. + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input, + + DefaultValuedAttr:$adjoint + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_MatrixSetDiagOp : TF_Op<"MatrixSetDiag", [NoSideEffect]> { let summary = [{ Returns a batched matrix tensor with new batched diagonal values. @@ -5715,6 +6202,100 @@ tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="LEFT_RIGHT") TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_MatrixSolveOp : TF_Op<"MatrixSolve", [NoSideEffect]> { + let summary = "Solves systems of linear equations."; + + let description = [{ +`Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is +a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix +satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. +If `adjoint` is `True` then each output matrix satisfies +`adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`. + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$matrix, + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$rhs, + + DefaultValuedAttr:$adjoint + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_MatrixTriangularSolveOp : TF_Op<"MatrixTriangularSolve", [NoSideEffect]> { + let summary = [{ +Solves systems of linear equations with upper or lower triangular matrices by backsubstitution. + }]; + + let description = [{ +`matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form +square matrices. If `lower` is `True` then the strictly upper triangular part +of each inner-most matrix is assumed to be zero and not accessed. 
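Because the `LowerBound` description above defines each row as `np.searchsorted(sorted_inputs, values, side='left')`, its 2-D example can be checked directly. The row-wise loop below only illustrates the semantics; it is not how the kernel is implemented:

```python
import numpy as np

sorted_sequence = np.array([[0, 3, 9, 9, 10],
                            [1, 2, 3, 4, 5]])
values = np.array([[2, 4, 9],
                   [0, 2, 6]])

# Each row is handled independently, exactly as in the op description.
result = np.stack([np.searchsorted(s, v, side="left")
                   for s, v in zip(sorted_sequence, values)])
print(result)  # [[1 2 2]
               #  [0 1 5]]
```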
+If `lower` is False then the strictly lower triangular part of each inner-most +matrix is assumed to be zero and not accessed. +`rhs` is a tensor of shape `[..., M, N]`. + +The output is a tensor of shape `[..., M, N]`. If `adjoint` is +`True` then the innermost matrices in `output` satisfy matrix equations +`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. +If `adjoint` is `False` then the strictly then the innermost matrices in +`output` satisfy matrix equations +`adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`. + +Note, the batch shapes for the inputs only need to broadcast. + +Example: +```python + +a = tf.constant([[3, 0, 0, 0], + [2, 1, 0, 0], + [1, 0, 1, 0], + [1, 1, 1, 1]], dtype=tf.float32) + +b = tf.constant([[4], + [2], + [4], + [2]], dtype=tf.float32) + +x = tf.linalg.triangular_solve(a, b, lower=True) +x +# + +# in python3 one can use `a@x` +tf.matmul(a, x) +# +``` + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$matrix, + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$rhs, + + DefaultValuedAttr:$lower, + DefaultValuedAttr:$adjoint + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_MaxOp : TF_Op<"Max", [NoSideEffect]> { let summary = [{ Computes the maximum of elements across dimensions of a tensor. @@ -5728,14 +6309,14 @@ retained with length 1. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -5755,7 +6336,8 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInter Confined]>:$ksize, Confined]>:$strides, - TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding, + TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, + DefaultValuedAttr:$explicit_paddings, DefaultValuedAttr, "NHWC">:$data_format ); @@ -5824,7 +6406,8 @@ def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> { Confined]>:$ksize, Confined]>:$strides, - TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding, + TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, + DefaultValuedAttr:$explicit_paddings, DefaultValuedAttr:$data_format ); @@ -5839,25 +6422,60 @@ def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> { }]; } -def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, - WithBroadcastableBinOpBuilder { - let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise."; +def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { + let summary = "Computes the mean of elements across dimensions of a tensor."; let description = [{ -*NOTE*: `Maximum` supports broadcasting. 
More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +Reduces `input` along the dimensions given in `axis`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`axis`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TF_I32OrI64Tensor:$reduction_indices, + + DefaultValuedAttr:$keep_dims ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$z + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; + + let extraClassDeclaration = [{ + // TF_FoldOperandsTransposeInterface: + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {}; } + LogicalResult FoldOperandsPermutation(ArrayRef permutation); + }]; +} + +def TF_MergeSummaryOp : TF_Op<"MergeSummary", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Merges summaries."; + + let description = [{ +This op creates a +[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) +protocol buffer that contains the union of all the values in the input +summaries. + +When the Op is run, it reports an `InvalidArgument` error if multiple values +in the summaries to merge use the same tag. + }]; + + let arguments = (ins + Variadic:$inputs + ); + + let results = (outs + TF_StrTensor:$summary + ); + + TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; } def TF_MergeV2CheckpointsOp : TF_Op<"MergeV2Checkpoints", []> { @@ -5899,14 +6517,14 @@ retained with length 1. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -5981,7 +6599,7 @@ pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2] } def TF_MlirLocalVarOp : TF_Op<"MlirLocalVarOp", []> { - let summary = "Creates a handle to a in-scope variable."; + let summary = "Creates a handle to an in-scope variable."; let description = [{ Used by internal passes for temporary representation of local state, which will @@ -5991,7 +6609,7 @@ be eventually removed. let arguments = (ins); let results = (outs - TF_ResourceTensor:$resource + Res:$resource ); } @@ -6072,7 +6690,7 @@ the result here is consistent with a truncating divide. E.g. 
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>, +def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns x * y element-wise."; @@ -6137,6 +6755,85 @@ def TF_MultinomialOp : TF_Op<"Multinomial", [TF_CannotDuplicate]> { TF_DerivedResultTypeAttr output_dtype = TF_DerivedResultTypeAttr<0>; } +def TF_MutableDenseHashTableV2Op : TF_Op<"MutableDenseHashTableV2", []> { + let summary = [{ +Creates an empty hash table that uses tensors as the backing store. + }]; + + let description = [{ +It uses "open addressing" with quadratic reprobing to resolve +collisions. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + }]; + + let arguments = (ins + TF_Tensor:$empty_key, + TF_Tensor:$deleted_key, + + StrAttr:$container, + StrAttr:$shared_name, + DefaultValuedAttr:$use_node_name_sharing, + TypeAttr:$value_dtype, + DefaultValuedAttr({})">:$value_shape, + DefaultValuedAttr:$initial_num_buckets, + DefaultValuedAttr:$max_load_factor + ); + + let results = (outs + Res:$table_handle + ); + + TF_DerivedOperandTypeAttr key_dtype = TF_DerivedOperandTypeAttr<0>; +} + +def TF_MutableHashTableOfTensorsV2Op : TF_Op<"MutableHashTableOfTensorsV2", []> { + let summary = "Creates an empty hash table."; + + let description = [{ +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a vector. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + }]; + + let arguments = (ins + StrAttr:$container, + StrAttr:$shared_name, + DefaultValuedAttr:$use_node_name_sharing, + TypeAttr:$key_dtype, + TypeAttr:$value_dtype, + DefaultValuedAttr({})">:$value_shape + ); + + let results = (outs + Res:$table_handle + ); +} + +def TF_MutableHashTableV2Op : TF_Op<"MutableHashTableV2", []> { + let summary = "Creates an empty hash table."; + + let description = [{ +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + }]; + + let arguments = (ins + StrAttr:$container, + StrAttr:$shared_name, + DefaultValuedAttr:$use_node_name_sharing, + TypeAttr:$key_dtype, + TypeAttr:$value_dtype + ); + + let results = (outs + Res:$table_handle + ); +} + def TF_NdtriOp : TF_Op<"Ndtri", [NoSideEffect]> { let summary = ""; @@ -7233,9 +7930,6 @@ def TF_RangeDatasetOp : TF_Op<"RangeDataset", []> { Creates a dataset with a range of values. Corresponds to python's xrange. }]; - let description = [{ - }]; - let arguments = (ins I64Tensor:$start, I64Tensor:$stop, @@ -7340,33 +8034,6 @@ tf.real(input) ==> [-2.25, 3.25] TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; } -def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary]>, - WithBroadcastableBinOpBuilder { - let summary = "Returns x / y element-wise for real types."; - - let description = [{ -If `x` and `y` are reals, this will return the floating-point division. - -*NOTE*: `Div` supports broadcasting. 
More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - }]; - - let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y - ); - - let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - - let hasCanonicalizer = 1; - - let hasFolder = 1; -} - def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes the reciprocal of x element-wise."; @@ -7430,7 +8097,7 @@ most one RecvTPUEmbeddingActivations op in the TPU graph. TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>; } -def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> { +def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> { let summary = "Computes rectified linear: `max(features, 0)`."; let description = [{ @@ -7449,6 +8116,11 @@ array([ 0., 0., -0., 3.], dtype=float32) ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let extraClassDeclaration = [{ + // TF_ContractionFusableInterface: + Optional GetContractionFusion(); + }]; } def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> { @@ -7645,6 +8317,105 @@ Resize `images` to `size` using nearest neighbor interpolation. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_ResourceApplyAdaMaxOp : TF_Op<"ResourceApplyAdaMax", []> { + let summary = "Update '*var' according to the AdaMax algorithm."; + + let description = [{ +m_t <- beta1 * m_{t-1} + (1 - beta1) * g +v_t <- max(beta2 * v_{t-1}, abs(g)) +variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon) + }]; + + let arguments = (ins + Arg:$var, + Arg:$m, + Arg:$v, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1_power, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + +def TF_ResourceApplyAdadeltaOp : TF_Op<"ResourceApplyAdadelta", []> { + let summary = "Update '*var' according to the adadelta scheme."; + + let description = [{ +accum = rho() * accum + (1 - rho()) * grad.square(); +update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad; +update_accum = rho() * update_accum + (1 - rho()) * update.square(); +var 
-= update; + }]; + + let arguments = (ins + Arg:$var, + Arg:$accum, + Arg:$accum_update, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rho, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + +def TF_ResourceApplyAdagradOp : TF_Op<"ResourceApplyAdagrad", []> { + let summary = "Update '*var' according to the adagrad scheme."; + + let description = [{ +accum += grad * grad +var -= lr * grad * (1 / sqrt(accum)) + }]; + + let arguments = (ins + Arg:$var, + Arg:$accum, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking, + DefaultValuedAttr:$update_slots + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceApplyAdagradDAOp : TF_Op<"ResourceApplyAdagradDA", []> { + let summary = "Update '*var' according to the proximal adagrad scheme."; + + let arguments = (ins + Arg:$var, + Arg:$gradient_accumulator, + Arg:$gradient_squared_accumulator, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l1, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l2, + I64Tensor:$global_step, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + def TF_ResourceApplyAdagradV2Op : TF_Op<"ResourceApplyAdagradV2", []> { let summary = "Update '*var' according to the adagrad scheme."; @@ -7654,8 +8425,8 @@ var -= lr * grad * (1 / (sqrt(accum) + epsilon)) }]; let arguments = (ins - TF_ResourceTensor:$var, - TF_ResourceTensor:$accum, + Arg:$var, + Arg:$accum, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, @@ -7680,9 +8451,9 @@ $$\text{variable} 
:= \text{variable} - \text{lr}_t * m_t / (\sqrt{v_t} + \epsilo }]; let arguments = (ins - TF_ResourceTensor:$var, - TF_ResourceTensor:$m, - TF_ResourceTensor:$v, + Arg:$var, + Arg:$m, + Arg:$v, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1_power, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2_power, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, @@ -7700,6 +8471,32 @@ $$\text{variable} := \text{variable} - \text{lr}_t * m_t / (\sqrt{v_t} + \epsilo TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; } +def TF_ResourceApplyAddSignOp : TF_Op<"ResourceApplyAddSign", []> { + let summary = "Update '*var' according to the AddSign update."; + + let description = [{ +m_t <- beta1 * m_{t-1} + (1 - beta1) * g +update <- (alpha + sign_decay * sign(g) *sign(m)) * g +variable <- variable - lr_t * update + }]; + + let arguments = (ins + Arg:$var, + Arg:$m, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$alpha, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$sign_decay, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + def TF_ResourceApplyCenteredRMSPropOp : TF_Op<"ResourceApplyCenteredRMSProp", []> { let summary = "Update '*var' according to the centered RMSProp algorithm."; @@ -7725,10 +8522,10 @@ var <- var - mom }]; let arguments = (ins - TF_ResourceTensor:$var, - TF_ResourceTensor:$mg, - TF_ResourceTensor:$ms, - TF_ResourceTensor:$mom, + Arg:$var, + Arg:$mg, + Arg:$ms, + Arg:$mom, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rho, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum, @@ -7743,11 +8540,74 @@ var <- var - mom TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<4>; } +def TF_ResourceApplyFtrlOp : TF_Op<"ResourceApplyFtrl", []> { + let summary = "Update '*var' according to the Ftrl-proximal scheme."; + + let description = [{ +accum_new = accum + grad * grad +linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +accum = accum_new + }]; + + 
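The `ResourceApplyFtrl` pseudocode above transcribes directly into NumPy. The sketch below follows those equations verbatim, ignores `use_locking` and `multiply_linear_by_lr`, and uses illustrative initial values only:

```python
import numpy as np

def ftrl_update(var, accum, linear, grad, lr, l1, l2, lr_power):
    """Direct transcription of the ResourceApplyFtrl update shown above."""
    accum_new = accum + grad * grad
    linear = linear + grad - (accum_new ** -lr_power - accum ** -lr_power) / lr * var
    quadratic = 1.0 / (accum_new ** lr_power * lr) + 2 * l2
    var = np.where(np.abs(linear) > l1,
                   (np.sign(linear) * l1 - linear) / quadratic,
                   0.0)
    return var, accum_new, linear

var, accum, linear = np.zeros(3), np.full(3, 0.1), np.zeros(3)
var, accum, linear = ftrl_update(var, accum, linear,
                                 grad=np.array([0.5, -0.2, 0.0]),
                                 lr=0.05, l1=0.001, l2=0.001, lr_power=-0.5)
```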
let arguments = (ins + Arg:$var, + Arg:$accum, + Arg:$linear, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l1, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l2, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr_power, + + DefaultValuedAttr:$use_locking, + DefaultValuedAttr:$multiply_linear_by_lr + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + +def TF_ResourceApplyFtrlV2Op : TF_Op<"ResourceApplyFtrlV2", []> { + let summary = "Update '*var' according to the Ftrl-proximal scheme."; + + let description = [{ +grad_with_shrinkage = grad + 2 * l2_shrinkage * var +accum_new = accum + grad_with_shrinkage * grad_with_shrinkage +linear += grad_with_shrinkage + + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +accum = accum_new + }]; + + let arguments = (ins + Arg:$var, + Arg:$accum, + Arg:$linear, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l1, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l2, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l2_shrinkage, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr_power, + + DefaultValuedAttr:$use_locking, + DefaultValuedAttr:$multiply_linear_by_lr + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []> { let summary = "Update '*var' by subtracting 'alpha' * 'delta' from it."; let arguments = (ins - TF_ResourceTensor:$var, + Arg:$var, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$alpha, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$delta, @@ -7770,8 +8630,8 @@ var += accum }]; let arguments = (ins - TF_ResourceTensor:$var, - TF_ResourceTensor:$accum, + Arg:$var, + Arg:$accum, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, 
TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum, @@ -7796,8 +8656,8 @@ var -= lr * accum }]; let arguments = (ins - TF_ResourceTensor:$var, - TF_ResourceTensor:$accum, + Arg:$var, + Arg:$accum, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum, @@ -7811,6 +8671,116 @@ var -= lr * accum TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; } +def TF_ResourceApplyPowerSignOp : TF_Op<"ResourceApplyPowerSign", []> { + let summary = "Update '*var' according to the AddSign update."; + + let description = [{ +m_t <- beta1 * m_{t-1} + (1 - beta1) * g +update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g +variable <- variable - lr_t * update + }]; + + let arguments = (ins + Arg:$var, + Arg:$m, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$logbase, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$sign_decay, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceApplyProximalAdagradOp : TF_Op<"ResourceApplyProximalAdagrad", []> { + let summary = [{ +Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. 
+ }]; + + let description = [{ +accum += grad * grad +prox_v = var - lr * grad * (1 / sqrt(accum)) +var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} + }]; + + let arguments = (ins + Arg:$var, + Arg:$accum, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l1, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l2, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceApplyProximalGradientDescentOp : TF_Op<"ResourceApplyProximalGradientDescent", []> { + let summary = "Update '*var' as FOBOS algorithm with fixed learning rate."; + + let description = [{ +prox_v = var - alpha * delta +var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} + }]; + + let arguments = (ins + Arg:$var, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$alpha, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l1, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l2, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$delta, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; +} + +def TF_ResourceApplyRMSPropOp : TF_Op<"ResourceApplyRMSProp", []> { + let summary = "Update '*var' according to the RMSProp algorithm."; + + let description = [{ +Note that in dense implementation of this algorithm, ms and mom will +update even if the grad is zero, but in this sparse implementation, ms +and mom will not update in iterations during which the grad is zero. 
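The `ResourceApplyProximalGradientDescent` update above is compact enough to restate in NumPy. The sketch below is a direct transcription of the `prox_v` step and the soft-threshold step; all names and values are chosen only for illustration:

```python
import numpy as np

def proximal_gradient_descent(var, alpha, l1, l2, delta):
    """NumPy transcription of the ResourceApplyProximalGradientDescent update."""
    prox_v = var - alpha * delta
    return (np.sign(prox_v) / (1.0 + alpha * l2)
            * np.maximum(np.abs(prox_v) - alpha * l1, 0.0))

var = np.array([0.5, -0.3, 0.01])
var = proximal_gradient_descent(var, alpha=0.1, l1=0.05, l2=0.01,
                                delta=np.array([1.0, -1.0, 0.0]))
```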
+ +mean_square = decay * mean_square + (1-decay) * gradient ** 2 +Delta = learning_rate * gradient / sqrt(mean_square + epsilon) + +ms <- rho * ms_{t-1} + (1-rho) * grad * grad +mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +var <- var - mom + }]; + + let arguments = (ins + Arg:$var, + Arg:$ms, + Arg:$mom, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rho, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + def TF_ResourceGatherOp : TF_Op<"ResourceGather", []> { let summary = [{ Gather slices from the variable pointed to by `resource` according to `indices`. @@ -7833,7 +8803,7 @@ Produces an output tensor with shape `indices.shape + params.shape[1:]` where: }]; let arguments = (ins - TF_ResourceTensor:$resource, + Arg:$resource, TF_I32OrI64Tensor:$indices, DefaultValuedAttr:$batch_dims, @@ -7848,6 +8818,405 @@ Produces an output tensor with shape `indices.shape + params.shape[1:]` where: TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; } +def TF_ResourceScatterAddOp : TF_Op<"ResourceScatterAdd", []> { + let summary = "Adds sparse updates to the variable referenced by `resource`."; + + let description = [{ +This operation computes + + # Scalar indices + ref[indices, ...] += updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] += updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions add. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +
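The duplicate-index behaviour called out for `ResourceScatterAdd` (contributions at repeated indices add rather than overwrite) matches NumPy's unbuffered `np.add.at`. A small sketch with illustrative data:

```python
import numpy as np

ref = np.zeros(4)
indices = np.array([1, 3, 1])           # index 1 appears twice
updates = np.array([10.0, 20.0, 5.0])

# Unbuffered in-place add: equivalent to ref[indices] += updates with
# repeated indices accumulating, as the op description requires.
np.add.at(ref, indices, updates)
print(ref)  # [ 0. 15.  0. 20.]
```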
+ }]; + + let arguments = (ins + Arg:$resource, + TF_I32OrI64Tensor:$indices, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterDivOp : TF_Op<"ResourceScatterDiv", []> { + let summary = [{ +Divides sparse updates into the variable referenced by `resource`. + }]; + + let description = [{ +This operation computes + + # Scalar indices + ref[indices, ...] /= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] /= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions multiply. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +
+ }]; + + let arguments = (ins + Arg:$resource, + TF_I32OrI64Tensor:$indices, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterMaxOp : TF_Op<"ResourceScatterMax", []> { + let summary = [{ +Reduces sparse updates into the variable referenced by `resource` using the `max` operation. + }]; + + let description = [{ +This operation computes + + # Scalar indices + ref[indices, ...] = max(ref[indices, ...], updates[...]) + + # Vector indices (for each i) + ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions are combined. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +
+ }]; + + let arguments = (ins + Arg:$resource, + TF_I32OrI64Tensor:$indices, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterMinOp : TF_Op<"ResourceScatterMin", []> { + let summary = [{ +Reduces sparse updates into the variable referenced by `resource` using the `min` operation. + }]; + + let description = [{ +This operation computes + + # Scalar indices + ref[indices, ...] = min(ref[indices, ...], updates[...]) + + # Vector indices (for each i) + ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions are combined. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +
+ }]; + + let arguments = (ins + Arg:$resource, + TF_I32OrI64Tensor:$indices, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterMulOp : TF_Op<"ResourceScatterMul", []> { + let summary = [{ +Multiplies sparse updates into the variable referenced by `resource`. + }]; + + let description = [{ +This operation computes + + # Scalar indices + ref[indices, ...] *= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] *= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...] + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions multiply. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +
+ }]; + + let arguments = (ins + Arg:$resource, + TF_I32OrI64Tensor:$indices, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterNdAddOp : TF_Op<"ResourceScatterNdAdd", []> { + let summary = [{ +Applies sparse addition to individual values or slices in a Variable. + }]; + + let description = [{ +`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `ref`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +dimension of `ref`. + +`updates` is `Tensor` of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] +``` + +For example, say we want to add 4 scattered elements to a rank-1 tensor to +8 elements. In Python, that addition would look like this: + +```python +ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True) +indices = tf.constant([[4], [3], [1], [7]]) +updates = tf.constant([9, 10, 11, 12]) +add = tf.scatter_nd_add(ref, indices, updates) +with tf.Session() as sess: + print sess.run(add) +``` + +The resulting update to ref would look like this: + + [1, 13, 3, 14, 14, 6, 7, 20] + +See `tf.scatter_nd` for more details about how to make updates to +slices. + }]; + + let arguments = (ins + Arg:$ref, + TF_I32OrI64Tensor:$indices, + TF_Tensor:$updates, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterNdSubOp : TF_Op<"ResourceScatterNdSub", []> { + let summary = [{ +Applies sparse subtraction to individual values or slices in a Variable. + }]; + + let description = [{ +`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `ref`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +dimension of `ref`. + +`updates` is `Tensor` of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] +``` + +For example, say we want to subtract 4 scattered elements from a rank-1 tensor +with 8 elements. In Python, that subtraction would look like this: + +```python +ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True) +indices = tf.constant([[4], [3], [1], [7]]) +updates = tf.constant([9, 10, 11, 12]) +sub = tf.scatter_nd_sub(ref, indices, updates) +with tf.Session() as sess: + print sess.run(sub) +``` + +The resulting update to ref would look like this: + + [1, -9, 3, -6, -4, 6, 7, -4] + +See `tf.scatter_nd` for more details about how to make updates to +slices. 
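The docstring examples above are written in Python 2 with `tf.Session`; as a side note, the same updates can be expressed in TF 2.x against a `tf.Variable`, whose `scatter_nd_add`/`scatter_nd_sub` methods are expected to dispatch to these resource ops. A short sketch:

```python
import tensorflow as tf

ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])

ref.scatter_nd_add(indices, updates)
print(ref.numpy())   # [ 1 13  3 14 14  6  7 20]

# Subtracting the same updates restores the original contents.
ref.scatter_nd_sub(indices, updates)
print(ref.numpy())   # [1 2 3 4 5 6 7 8]
```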
+ }]; + + let arguments = (ins + Arg:$ref, + TF_I32OrI64Tensor:$indices, + TF_Tensor:$updates, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterNdUpdateOp : TF_Op<"ResourceScatterNdUpdate", []> { + let summary = [{ +Applies sparse `updates` to individual values or slices within a given + }]; + + let description = [{ +variable according to `indices`. + +`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `ref`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +dimension of `ref`. + +`updates` is `Tensor` of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +``` + +For example, say we want to update 4 scattered elements to a rank-1 tensor to +8 elements. In Python, that update would look like this: + +```python + ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + indices = tf.constant([[4], [3], [1] ,[7]]) + updates = tf.constant([9, 10, 11, 12]) + update = tf.scatter_nd_update(ref, indices, updates) + with tf.Session() as sess: + print sess.run(update) +``` + +The resulting update to ref would look like this: + + [1, 11, 3, 10, 9, 6, 7, 12] + +See `tf.scatter_nd` for more details about how to make updates to +slices. + }]; + + let arguments = (ins + Arg:$ref, + TF_I32OrI64Tensor:$indices, + TF_Tensor:$updates, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +def TF_ResourceScatterSubOp : TF_Op<"ResourceScatterSub", []> { + let summary = [{ +Subtracts sparse updates from the variable referenced by `resource`. + }]; + + let description = [{ +This operation computes + + # Scalar indices + ref[indices, ...] -= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] -= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions add. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +
+ +
+ }]; + + let arguments = (ins + Arg:$resource, + TF_I32OrI64Tensor:$indices, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + def TF_ResourceScatterUpdateOp : TF_Op<"ResourceScatterUpdate", []> { let summary = [{ Assigns sparse updates to the variable referenced by `resource`. @@ -7867,7 +9236,7 @@ This operation computes }]; let arguments = (ins - TF_ResourceTensor:$resource, + Arg:$resource, TF_I32OrI64Tensor:$indices, TF_Tensor:$updates ); @@ -7878,6 +9247,38 @@ This operation computes TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; } +def TF_ResourceStridedSliceAssignOp : TF_Op<"ResourceStridedSliceAssign", []> { + let summary = "Assign `value` to the sliced l-value reference of `ref`."; + + let description = [{ +The values of `value` are assigned to the positions in the variable +`ref` that are selected by the slice parameters. The slice parameters +`begin, `end`, `strides`, etc. work exactly as in `StridedSlice`. + +NOTE this op currently does not support broadcasting and so `value`'s +shape must be exactly the shape produced by the slice of `ref`. + }]; + + let arguments = (ins + Arg:$ref, + TF_I32OrI64Tensor:$begin, + TF_I32OrI64Tensor:$end, + TF_I32OrI64Tensor:$strides, + TF_Tensor:$value, + + DefaultValuedAttr:$begin_mask, + DefaultValuedAttr:$end_mask, + DefaultValuedAttr:$ellipsis_mask, + DefaultValuedAttr:$new_axis_mask, + DefaultValuedAttr:$shrink_axis_mask + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<4>; + TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>; +} + def TF_RestoreV2Op : TF_Op<"RestoreV2", []> { let summary = "Restores tensors from a V2 checkpoint."; @@ -8129,6 +9530,47 @@ rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> { + let summary = "Rolls the elements of a tensor along an axis."; + + let description = [{ +The elements are shifted positively (towards larger indices) by the offset of +`shift` along the dimension of `axis`. Negative `shift` values will shift +elements in the opposite direction. Elements that roll passed the last position +will wrap around to the first and vice versa. Multiple shifts along multiple +axes may be specified. 
+ +For example: + +``` +# 't' is [0, 1, 2, 3, 4] +roll(t, shift=2, axis=0) ==> [3, 4, 0, 1, 2] + +# shifting along multiple dimensions +# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] +roll(t, shift=[1, -2], axis=[0, 1]) ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]] + +# shifting along the same axis multiple times +# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] +roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]] +``` + }]; + + let arguments = (ins + TF_Tensor:$input, + TF_I32OrI64Tensor:$shift, + TF_I32OrI64Tensor:$axis + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr Tshift = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>; +} + def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Rounds the values of a tensor to the nearest integer, element-wise. @@ -8712,7 +10154,7 @@ This operation returns N 1-D integer tensors representing shape of `input[i]s`. return Verify(*this); }]; - let hasFolder = 1; + let hasCanonicalizer = 1; } def TF_ShardedFilenameOp : TF_Op<"ShardedFilename", [NoSideEffect]> { @@ -8735,6 +10177,18 @@ Generate a sharded filename. The filename is printf formatted as ); } +def TF_ShutdownDistributedTPUOp : TF_Op<"ShutdownDistributedTPU", []> { + let summary = "Shuts down a running distributed TPU system."; + + let description = [{ +The op returns an error if no system is running. + }]; + + let arguments = (ins); + + let results = (outs); +} + def TF_SigmoidOp : TF_Op<"Sigmoid", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes sigmoid of `x` element-wise."; @@ -8876,6 +10330,8 @@ size(t) ==> 12 let verifier = [{ return Verify(*this); }]; + + let hasFolder = 1; } def TF_SliceOp : TF_Op<"Slice", [NoSideEffect]> { @@ -9251,6 +10707,41 @@ backpropagation, TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; } +def TF_SparseMatMulOp : TF_Op<"SparseMatMul", [NoSideEffect]> { + let summary = [{ +Multiply matrix "a" by matrix "b". + }]; + + let description = [{ +The inputs must be two-dimensional matrices and the inner dimension of "a" must +match the outer dimension of "b". Both "a" and "b" must be `Tensor`s not +`SparseTensor`s. This op is optimized for the case where at least one of "a" or +"b" is sparse, in the sense that they have a large proportion of zero values. +The breakeven for using this versus a dense matrix multiply on one platform was +30% zero values in the sparse matrix. + +The gradient computation of this operation will only take advantage of sparsity +in the input gradient when that gradient comes from a Relu. + }]; + + let arguments = (ins + TensorOf<[BF16, F32]>:$a, + TensorOf<[BF16, F32]>:$b, + + DefaultValuedAttr:$transpose_a, + DefaultValuedAttr:$transpose_b, + DefaultValuedAttr:$a_is_sparse, + DefaultValuedAttr:$b_is_sparse + ); + + let results = (outs + F32Tensor:$product + ); + + TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tb = TF_DerivedOperandTypeAttr<1>; +} + def TF_SparseReshapeOp : TF_Op<"SparseReshape", [NoSideEffect]> { let summary = [{ Reshapes a SparseTensor to represent values in a new dense shape. @@ -9482,7 +10973,7 @@ I.e., \\(y = x * x = x^2\\). 
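A usage sketch for the `SparseMatMul` op defined above (illustrative values, TF 2.x eager assumed); the `tf.raw_ops.SparseMatMul` wrapper takes keyword arguments matching the `a`/`b`/`a_is_sparse` names in this definition.

```python
import tensorflow as tf

# "a" is a dense tensor that is mostly zero, the case SparseMatMul targets.
a = tf.constant([[0.0, 0.0, 1.0],
                 [0.0, 2.0, 0.0]])
b = tf.constant([[1.0, 1.0],
                 [1.0, 1.0],
                 [1.0, 1.0]])

product = tf.raw_ops.SparseMatMul(a=a, b=b, a_is_sparse=True)
print(product.numpy())  # [[1. 1.]
                        #  [2. 2.]]
```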
def TF_SquaredDifferenceOp : TF_Op<"SquaredDifference", [Commutative, NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { - let summary = "Returns (x - y)(x - y) element-wise."; + let summary = "Returns conj(x - y)(x - y) element-wise."; let description = [{ *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting @@ -9542,7 +11033,7 @@ def TF_StackCloseV2Op : TF_Op<"StackCloseV2", []> { let summary = "Delete the stack from its resource container."; let arguments = (ins - TF_ResourceTensor:$handle + Arg:$handle ); let results = (outs); @@ -9552,7 +11043,7 @@ def TF_StackPopV2Op : TF_Op<"StackPopV2", []> { let summary = "Pop the element at the top of the stack."; let arguments = (ins - TF_ResourceTensor:$handle + Arg:$handle ); let results = (outs @@ -9566,7 +11057,7 @@ def TF_StackPushV2Op : TF_Op<"StackPushV2", []> { let summary = "Push an element onto the stack."; let arguments = (ins - TF_ResourceTensor:$handle, + Arg:$handle, TF_Tensor:$elem, DefaultValuedAttr:$swap_memory @@ -9590,10 +11081,53 @@ def TF_StackV2Op : TF_Op<"StackV2", []> { ); let results = (outs - TF_ResourceTensor:$handle + Res:$handle ); } +def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect]> { + let summary = "Draws samples from a multinomial distribution."; + + let arguments = (ins + TF_IntOrFpTensor:$logits, + I32Tensor:$num_samples, + TF_I32OrI64Tensor:$seed + ); + + let results = (outs + TF_I32OrI64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<2>; + TF_DerivedResultTypeAttr output_dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_StatelessRandomNormalOp : TF_Op<"StatelessRandomNormal", [NoSideEffect]> { + let summary = [{ +Outputs deterministic pseudorandom values from a normal distribution. + }]; + + let description = [{ +The generated values will have mean 0 and standard deviation 1. + +The outputs are a deterministic function of `shape` and `seed`. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$seed + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect]> { let summary = [{ Outputs deterministic pseudorandom random values from a uniform distribution. @@ -9620,6 +11154,33 @@ The outputs are a deterministic function of `shape` and `seed`. TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; } +def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect]> { + let summary = [{ +Outputs deterministic pseudorandom random integers from a uniform distribution. + }]; + + let description = [{ +The generated values follow a uniform distribution in the range `[minval, maxval)`. + +The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`. 
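As a usage note on the stateless RNG ops added above (a sketch, assuming the `tf.random.stateless_*` Python wrappers forward `shape`, `seed`, `minval` and `maxval` to these kernels): repeated calls with identical arguments return identical values.

```python
import tensorflow as tf

seed = tf.constant([7, 42], dtype=tf.int64)   # shape-[2] seed

a = tf.random.stateless_uniform([3], seed=seed, minval=0, maxval=10,
                                dtype=tf.int32)
b = tf.random.stateless_uniform([3], seed=seed, minval=0, maxval=10,
                                dtype=tf.int32)

# Deterministic: same shape/seed/minval/maxval => same output.
assert bool(tf.reduce_all(a == b))

n1 = tf.random.stateless_normal([2, 2], seed=seed)
n2 = tf.random.stateless_normal([2, 2], seed=seed)
assert bool(tf.reduce_all(n1 == n2))
```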
+ }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$seed, + TF_I32OrI64Tensor:$minval, + TF_I32OrI64Tensor:$maxval + ); + + let results = (outs + TF_I32OrI64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect]> { let summary = [{ Outputs deterministic pseudorandom values from a truncated normal distribution. @@ -9889,7 +11450,37 @@ Examples: TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; } -def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>, +def TF_StringToHashBucketFastOp : TF_Op<"StringToHashBucketFast", [NoSideEffect]> { + let summary = [{ +Converts each string in the input Tensor to its hash mod by a number of buckets. + }]; + + let description = [{ +The hash function is deterministic on the content of the string within the +process and will never change. However, it is not suitable for cryptography. +This function may be used when CPU time is scarce and inputs are trusted or +unimportant. There is a risk of adversaries constructing inputs that all hash +to the same bucket. To prevent this problem, use a strong hash function with +`tf.string_to_hash_bucket_strong`. + +Examples: + +>>> tf.strings.to_hash_bucket_fast(["Hello", "TensorFlow", "2.x"], 3).numpy() +array([0, 2, 2]) + }]; + + let arguments = (ins + TF_StrTensor:$input, + + Confined]>:$num_buckets + ); + + let results = (outs + I64Tensor:$output + ); +} + +def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns x - y element-wise."; @@ -9944,6 +11535,25 @@ retained with length 1. >]; } +def TF_SymbolicGradientOp : TF_Op<"SymbolicGradient", [NoSideEffect]> { + let summary = [{ +Computes the gradient function for function f via backpropagation. + }]; + + let arguments = (ins + Variadic:$input, + + SymbolRefAttr:$f + ); + + let results = (outs + Variadic:$output + ); + + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; +} + def TF_TPUCompilationResultOp : TF_Op<"TPUCompilationResult", [NoSideEffect]> { let summary = "Returns the result of a TPU compilation."; @@ -10158,9 +11768,9 @@ variables. }]; let arguments = (ins - Variadic:$vars, + Arg, "", [TF_VariableRead, TF_VariableWrite]>:$vars, TF_StrTensor:$new_format_key, - TF_ResourceTensor:$format_state_var + Arg:$format_state_var ); let results = (outs); @@ -10248,7 +11858,7 @@ of a step/run. }]; let arguments = (ins - TF_ResourceTensor:$handle + Arg:$handle ); let results = (outs); @@ -10272,7 +11882,7 @@ All elements must have the same shape (excepting the first dimension). }]; let arguments = (ins - TF_ResourceTensor:$handle, + Arg:$handle, F32Tensor:$flow_in, DefaultValuedAttr:$element_shape_except0 @@ -10296,7 +11906,7 @@ All elements selected by `indices` must have the same shape. }]; let arguments = (ins - TF_ResourceTensor:$handle, + Arg:$handle, I32Tensor:$indices, F32Tensor:$flow_in, @@ -10355,14 +11965,14 @@ calculation gets its own TensorArray accumulator. 
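The `TensorArray*V3` ops being annotated in this and the following hunks back the `tf.TensorArray` Python class; a minimal usage sketch follows, with the caveat that newer TF 2.x code paths may lower `tf.TensorArray` to TensorList ops rather than these V3 kernels.

```python
import tensorflow as tf

# tf.TensorArray is the user-facing API historically implemented by the
# TensorArray*V3 kernels (write/read/stack map to WriteV3/ReadV3/GatherV3);
# TF 2.x may instead use TensorList ops under the hood.
ta = tf.TensorArray(dtype=tf.float32, size=3)
ta = ta.write(0, 10.0)   # write returns the updated TensorArray
ta = ta.write(1, 20.0)
ta = ta.write(2, 30.0)

print(ta.read(1).numpy())    # 20.0
print(ta.stack().numpy())    # [10. 20. 30.]
```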
}]; let arguments = (ins - TF_ResourceTensor:$handle, + Arg:$handle, F32Tensor:$flow_in, StrAttr:$source ); let results = (outs - TF_ResourceTensor:$grad_handle, + Res:$grad_handle, F32Tensor:$flow_out ); } @@ -10371,7 +11981,7 @@ def TF_TensorArrayReadV3Op : TF_Op<"TensorArrayReadV3", []> { let summary = "Read an element from the TensorArray into output `value`."; let arguments = (ins - TF_ResourceTensor:$handle, + Arg:$handle, I32Tensor:$index, F32Tensor:$flow_in ); @@ -10393,7 +12003,7 @@ Scatter the data from the input value into specific TensorArray elements. }]; let arguments = (ins - TF_ResourceTensor:$handle, + Arg:$handle, I32Tensor:$indices, TF_Tensor:$value, F32Tensor:$flow_in @@ -10410,7 +12020,7 @@ def TF_TensorArraySizeV3Op : TF_Op<"TensorArraySizeV3", []> { let summary = "Get the current size of the TensorArray."; let arguments = (ins - TF_ResourceTensor:$handle, + Arg:$handle, F32Tensor:$flow_in ); @@ -10445,7 +12055,7 @@ and having size }]; let arguments = (ins - TF_ResourceTensor:$handle, + Arg:$handle, TF_Tensor:$value, I64Tensor:$lengths, F32Tensor:$flow_in @@ -10477,7 +12087,7 @@ Write data via Write and read via Read or Pack. ); let results = (outs - TF_ResourceTensor:$handle, + Res:$handle, F32Tensor:$flow ); } @@ -10486,7 +12096,7 @@ def TF_TensorArrayWriteV3Op : TF_Op<"TensorArrayWriteV3", []> { let summary = "Push an element onto the tensor_array."; let arguments = (ins - TF_ResourceTensor:$handle, + Arg:$handle, I32Tensor:$index, TF_Tensor:$value, F32Tensor:$flow_in @@ -10881,6 +12491,40 @@ On GPU, if an out of bound index is found, the index is ignored. ]; } +def TF_TensorStridedSliceUpdateOp : TF_Op<"TensorStridedSliceUpdate", [NoSideEffect]> { + let summary = "Assign `value` to the sliced l-value reference of `input`."; + + let description = [{ +The values of `value` are assigned to the positions in the tensor `input` that +are selected by the slice parameters. The slice parameters `begin` `end` +`strides` etc. work exactly as in `StridedSlice`. + +NOTE this op currently does not support broadcasting and so `value`'s shape +must be exactly the shape produced by the slice of `input`. + }]; + + let arguments = (ins + TF_Tensor:$input, + TF_I32OrI64Tensor:$begin, + TF_I32OrI64Tensor:$end, + TF_I32OrI64Tensor:$strides, + TF_Tensor:$value, + + DefaultValuedAttr:$begin_mask, + DefaultValuedAttr:$end_mask, + DefaultValuedAttr:$ellipsis_mask, + DefaultValuedAttr:$new_axis_mask, + DefaultValuedAttr:$shrink_axis_mask + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>; +} + def TF_TileOp : TF_Op<"Tile", [NoSideEffect]> { let summary = "Constructs a tensor by tiling a given tensor."; @@ -10925,46 +12569,9 @@ array([[1, 2, 3, 1, 2, 3], TF_DerivedOperandTypeAttr Tmultiples = TF_DerivedOperandTypeAttr<1>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - // TODO(parkers): Add folds for multiples = [1,...]. - // TODO(parkers): Add errors for negative multiples and multiples.size() != - // input.rank() -} + let verifier = [{ return Verify(*this); }]; -def TF_ToBoolOp : TF_Op<"ToBool", [NoSideEffect]> { - let summary = "Converts a tensor to a scalar predicate."; - - let description = [{ -Converts a tensor to a scalar predicate with the following rules: - -- For 0D tensors, truthiness is determined by comparing against a "zero" - value. For numerical types it is the obvious zero. For strings it is the - empty string. 
- -- For >0D tensors, truthiness is determined by looking at the number of - elements. If has zero elements, then the result is false. Otherwise the - result is true. - -This matches the behavior of If and While for determining if a tensor counts -as true/false for a branch condition. - }]; - - let arguments = (ins - TF_Tensor:$input - ); - - let results = (outs - I1Tensor:$output - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value value", [{ - build(builder, result, RankedTensorType::get({}, builder.getI1Type()), - value); - }]>]; - - let hasCanonicalizer = 1; + let hasFolder = 1; } def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> { @@ -11388,13 +12995,51 @@ tf.unsorted_segment_sum(c, tf.constant([0, 1, 0]), num_segments=2) let verifier = [{ return VerifyUnsortedSegmentReduction(*this); }]; } +def TF_UpperBoundOp : TF_Op<"UpperBound", [NoSideEffect]> { + let summary = [{ +Applies upper_bound(sorted_search_values, values) along each row. + }]; + + let description = [{ +Each set of rows with the same index in (sorted_inputs, values) is treated +independently. The resulting row is the equivalent of calling +`np.searchsorted(sorted_inputs, values, side='right')`. + +The result is not a global index to the entire +`Tensor`, but rather just the index in the last dimension. + +A 2-D example: + sorted_sequence = [[0, 3, 9, 9, 10], + [1, 2, 3, 4, 5]] + values = [[2, 4, 9], + [0, 2, 6]] + + result = UpperBound(sorted_sequence, values) + + result == [[1, 2, 4], + [0, 2, 5]] + }]; + + let arguments = (ins + TF_Tensor:$sorted_inputs, + TF_Tensor:$values + ); + + let results = (outs + TF_I32OrI64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>; +} + def TF_VarIsInitializedOp : TF_Op<"VarIsInitializedOp", []> { let summary = [{ Checks whether a resource handle-based variable has been initialized. }]; let arguments = (ins - TF_ResourceTensor:$resource + Arg:$resource ); let results = (outs @@ -11419,7 +13064,7 @@ shape(t) ==> [2, 2, 3] }]; let arguments = (ins - TF_ResourceTensor:$input + Arg:$input ); let results = (outs @@ -11563,14 +13208,14 @@ for binary operators. 
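Returning to the `UpperBound` op defined above: its row-wise `np.searchsorted(..., side='right')` behaviour is what `tf.searchsorted` exposes in Python, so the documented 2-D example can be reproduced as below (a sketch, assuming TF 2.x eager execution).

```python
import tensorflow as tf

sorted_sequence = tf.constant([[0, 3, 9, 9, 10],
                               [1, 2, 3, 4, 5]])
values = tf.constant([[2, 4, 9],
                      [0, 2, 6]])

# side='right' corresponds to the UpperBound kernel.
result = tf.searchsorted(sorted_sequence, values, side='right')
print(result.numpy())  # [[1 2 4]
                       #  [0 2 5]]
```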
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lhs, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rhs, - TF_I32OrI64Tensor:$broadcast_dims + Arg, [{the LHS input tensor}]>:$lhs, + Arg, [{the RHS input tensor}]>:$rhs, + Arg:$broadcast_dims ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lhs_output, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rhs_output + Res, [{the broadcasted LHS tensor}]>:$lhs_output, + Res, [{the broadcasted RHS tensor}]>:$rhs_output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>; @@ -11586,13 +13231,13 @@ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lhs, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rhs, - TF_I32OrI64Tensor:$window_strides, - TF_I32OrI64Tensor:$padding, - TF_I32OrI64Tensor:$lhs_dilation, - TF_I32OrI64Tensor:$rhs_dilation, - TF_I32OrI64Tensor:$feature_group_count, + Arg, [{the input tensor}]>:$lhs, + Arg, [{the kernel tensor}]>:$rhs, + Arg:$window_strides, + Arg:$padding, + Arg:$lhs_dilation, + Arg:$rhs_dilation, + Arg:$feature_group_count, StrAttr:$dimension_numbers, StrAttr:$precision_config @@ -11615,8 +13260,8 @@ https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lhs, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rhs, + Arg, [{the LHS tensor}]>:$lhs, + Arg, [{the RHS tensor}]>:$rhs, StrAttr:$dimension_numbers, StrAttr:$precision_config @@ -11644,8 +13289,11 @@ with dimension size equal to the rank of operand. }]; let arguments = (ins - TF_Tensor:$input, - TF_I32OrI64Tensor:$start_indices, + Arg:$input, + Arg:$start_indices, TF_I32OrI64Tensor:$size_indices ); @@ -11673,13 +13321,14 @@ Handling of out-of-bounds slice indices is implementation-defined. 
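As a note on `XlaDynamicSlice`: it takes runtime start indices plus static slice sizes, which is the same contract as `tf.slice` with a tensor-valued `begin`. The sketch below uses that plain-TF analogue rather than invoking the XLA op directly, purely to illustrate the semantics.

```python
import tensorflow as tf

x = tf.reshape(tf.range(12), [3, 4])
start = tf.constant([1, 1])      # start indices known only at runtime
size = [2, 2]                    # slice sizes fixed at graph-build time

print(tf.slice(x, begin=start, size=size).numpy())
# [[ 5  6]
#  [ 9 10]]
```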
}]; let arguments = (ins - TF_Tensor:$input, - TF_Tensor:$update, - TF_I32OrI64Tensor:$indices + Arg:$input, + Arg:$update, + Arg:$indices ); let results = (outs - TF_Tensor:$output + Res:$output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>; @@ -11694,9 +13343,9 @@ https://www.tensorflow.org/xla/operation_semantics#gather }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$operand, - TF_I32OrI64Tensor:$start_indices, - TF_I32OrI64Tensor:$slice_sizes, + Arg, [{The array we're gathering from.}]>:$operand, + Arg:$start_indices, + Arg:$slice_sizes, StrAttr:$dimension_numbers, BoolAttr:$indices_are_sorted @@ -11745,13 +13394,13 @@ Sorts a tensor. Currently only sorts in ascending order are supported. }]; let arguments = (ins - TF_IntOrFpTensor:$keys, - TF_Tensor:$values + Arg:$keys, + Arg:$values ); let results = (outs - TF_IntOrFpTensor:$sorted_keys, - TF_Tensor:$sorted_values + Res:$sorted_keys, + Res:$sorted_values ); TF_DerivedOperandTypeAttr V = TF_DerivedOperandTypeAttr<1>; @@ -11767,15 +13416,15 @@ https://www.tensorflow.org/performance/xla/operation_semantics#pad }]; let arguments = (ins - TF_Tensor:$input, - TF_Tensor:$padding_value, - TF_I32OrI64Tensor:$padding_low, - TF_I32OrI64Tensor:$padding_high, - TF_I32OrI64Tensor:$padding_interior + Arg:$input, + Arg:$padding_value, + Arg:$padding_low, + Arg:$padding_high, + Arg:$padding_interior ); let results = (outs - TF_Tensor:$output + Res:$output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>; @@ -11785,6 +13434,13 @@ https://www.tensorflow.org/performance/xla/operation_semantics#pad def TF_XlaRecvFromHostOp : TF_Op<"XlaRecvFromHost", []> { let summary = "An op to receive a tensor from the host."; + let description = [{ +output: the tensor that will be received from the host. +Toutput: element type for output. +shape: shape for output. +key: A unique identifier for this region used to match up host transfers. + }]; + let arguments = (ins TF_ShapeAttr:$shape, StrAttr:$key @@ -11805,8 +13461,8 @@ https://www.tensorflow.org/performance/xla/operation_semantics#reduce . }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$init_value, + Arg, [{the input tensor}]>:$input, + Arg, [{a scalar representing the initial value for the reduction}]>:$init_value, I64ArrayAttr:$dimensions_to_reduce, SymbolRefAttr:$reducer @@ -11829,6 +13485,32 @@ def TF_XlaReplicaIdOp : TF_Op<"XlaReplicaId", [NoSideEffect]> { ); } +def TF_XlaScatterOp : TF_Op<"XlaScatter", [NoSideEffect]> { + let summary = "Wraps the XLA Scatter operator documented at"; + + let description = [{ +https://www.tensorflow.org/xla/operation_semantics#scatter. 
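`XlaScatter` generalizes functional scatter with an arbitrary `update_computation`; for the common "replace" case, plain TF offers `tf.tensor_scatter_nd_update`, shown here only as an illustration of functional (non-resource) scatter semantics, not as an equivalent of the full op.

```python
import tensorflow as tf

operand = tf.zeros([8], dtype=tf.int32)
indices = tf.constant([[1], [4], [6]])
updates = tf.constant([10, 20, 30])

# Functional scatter: returns a new tensor, `operand` is left unchanged.
result = tf.tensor_scatter_nd_update(operand, indices, updates)
print(result.numpy())  # [ 0 10  0  0 20  0 30  0]
```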
+ }]; + + let arguments = (ins + Arg, [{Array to be scattered into.}]>:$operand, + Arg:$scatter_indices, + Arg, [{Array containing the values that must be used for scattering.}]>:$updates, + + SymbolRefAttr:$update_computation, + StrAttr:$dimension_numbers, + BoolAttr:$indices_are_sorted + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_XlaSelfAdjointEigOp : TF_Op<"XlaSelfAdjointEig", [NoSideEffect]> { let summary = [{ Computes the eigen decomposition of a batch of self-adjoint matrices @@ -11843,7 +13525,7 @@ i=0...N-1. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$a, + Arg, [{the input tensor.}]>:$a, BoolAttr:$lower, I64Attr:$max_iter, @@ -11851,8 +13533,10 @@ i=0...N-1. ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$w, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$v + Res, [{The eigenvalues in ascending order, each repeated according to its +multiplicity.}]>:$w, + Res, [{The column v[..., :, i] is the normalized eigenvector corresponding to the +eigenvalue w[..., i].}]>:$v ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -11861,6 +13545,12 @@ i=0...N-1. def TF_XlaSendToHostOp : TF_Op<"XlaSendToHost", []> { let summary = "An op to send a tensor to the host."; + let description = [{ +input: the tensor that will be sent to the host. +Tinput: element type for input. +key: A unique identifier for this region used to match up host transfers. + }]; + let arguments = (ins TF_Tensor:$input, @@ -11885,7 +13575,7 @@ tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[ }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$a, + Arg, [{the input tensor.}]>:$a, I64Attr:$max_iter, F32Attr:$epsilon, @@ -11893,9 +13583,10 @@ tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[ ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$s, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$u, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$v + Res, [{Singular values. 
The values are sorted in reverse order of magnitude, so +s[..., 0] is the largest value, s[..., 1] is the second largest, etc.}]>:$s, + Res, [{Left singular vectors.}]>:$u, + Res, [{Right singular vectors.}]>:$v ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -11946,6 +13637,43 @@ def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF__FusedBatchNormExOp : TF_Op<"_FusedBatchNormEx", [NoSideEffect]> { + let summary = "Internal FusedBatchNorm operation: reserved for internal use."; + + let description = [{ +Do not invoke this operator directly in Python. A fusion optimization is +expected to create these operators. + }]; + + let arguments = (ins + TensorOf<[F16, F32]>:$x, + F32Tensor:$scale, + F32Tensor:$offset, + F32Tensor:$mean, + F32Tensor:$variance, + Variadic>:$side_input, + + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$exponential_avg_factor, + DefaultValuedAttr:$activation_mode, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$is_training + ); + + let results = (outs + TensorOf<[F16, F32]>:$y, + F32Tensor:$batch_mean, + F32Tensor:$batch_variance, + F32Tensor:$reserve_space_1, + F32Tensor:$reserve_space_2, + F32Tensor:$reserve_space_3 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandSizeAttr num_side_inputs = TF_DerivedOperandSizeAttr<5>; +} + def TF__FusedConv2DOp : TF_Op<"_FusedConv2D", [NoSideEffect]> { let summary = [{ Performs a convolution followed by a specified series of operations. @@ -11983,7 +13711,8 @@ create these operators. DefaultValuedAttr:$dilations, DefaultValuedAttr:$use_cudnn_on_gpu, DefaultValuedAttr:$fused_ops, - DefaultValuedAttr:$epsilon + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$leakyrelu_alpha ); let results = (outs @@ -12049,13 +13778,19 @@ Tensor of activations per table specified in the model. }]; let arguments = (ins - TF_VariantTensor:$deduplication_data, + Arg:$deduplication_data, StrAttr:$config ); let results = (outs - Variadic:$outputs + Res, [{A TensorList of embedding activations containing one Tensor per +embedding table in the model.}]>:$outputs ); TF_DerivedResultSizeAttr num_tables = TF_DerivedResultSizeAttr<0>; @@ -12067,18 +13802,17 @@ Compiles a computations for execution on one or more TPU devices. }]; let description = [{ -For the internal use of the distributed TPU compiler. Note that currently only -single TPU device is supported. +For the internal use of the distributed TPU compiler. 'mlir_module' is a serialized MLIR module with a `main` function that contains target computation. 'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not known statically at TPUReplication rewrite time. -'metadata' is a serialized TPUCompileMetadataProto describing -the shapes and types of the inputs to the computation, as well as a mapping onto -the TPU pod topology. -'program' output is a string key that is passed to the _TPUExecute op and -used to look up the program in the compilation cache. +'metadata' is a serialized TPUCompileMetadataProto describing the shapes and +types of the inputs to the computation, as well as a mapping onto the TPU pod +topology. +'program' output is a string key that is passed to the TPUExecute op and used to +look up the program in the compilation cache. 
}]; let arguments = (ins @@ -12115,13 +13849,35 @@ rewrite passes must replace this op with a _TPUCompileMlir op `program` output. ); } +def TF__UnaryOpsCompositionOp : TF_Op<"_UnaryOpsComposition", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is + }]; + + let description = [{ +expected to create these operators. + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64]>:$x, + + StrArrayAttr:$op_names + ); + + let results = (outs + TensorOf<[F16, F32, F64]>:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF__XlaHostComputeMlirOp : TF_Op<"_XlaHostComputeMlir", []> { let summary = [{ A pseudo-op to represent host-side computation in an XLA program. }]; let arguments = (ins - Variadic:$inputs, + Arg, [{A list of tensors that will be sent to the host.}]>:$inputs, StrAttr:$send_key, StrAttr:$recv_key, @@ -12129,7 +13885,7 @@ A pseudo-op to represent host-side computation in an XLA program. ); let results = (outs - Variadic:$outputs + Res, [{A list of tensors that will be returned to the device.}]>:$outputs ); TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; @@ -12142,14 +13898,15 @@ A placeholder op to receive values from a running XLA computation. }]; let arguments = (ins - TF_StrTensor:$dynamic_key, + Arg:$dynamic_key, StrAttr:$key, I64Attr:$device_ordinal ); let results = (outs - Variadic:$outputs + Res, [{A list of tensors that will be received from the XLA computation.}]>:$outputs ); TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; @@ -12159,8 +13916,9 @@ def TF__XlaSendFromHostOp : TF_Op<"_XlaSendFromHost", []> { let summary = "A placeholder op to send values to a running XLA computation."; let arguments = (ins - Variadic:$inputs, - TF_StrTensor:$dynamic_key, + Arg, [{A list of tensors that will be sent to the XLA computation.}]>:$inputs, + Arg:$dynamic_key, StrAttr:$key, I64Attr:$device_ordinal diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 1755c975c23..1edae47cfe6 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -46,7 +46,7 @@ Invariants: TODO: Make invariants more structured so that we can reference them in ops. 
}]; - let cppNamespace = "TF"; + let cppNamespace = "::mlir::TF"; } //===----------------------------------------------------------------------===// @@ -108,14 +108,29 @@ class TF_ResourceBase : def TF_VariableResource : TF_ResourceBase<"Variable">; def TF_StackResource : TF_ResourceBase<"Stack">; def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">; +def TF_SummaryResource : TF_ResourceBase<"Summary">; +def TF_LookupTableResource : TF_ResourceBase<"LookupTable">; def TF_VariableRead : MemRead; def TF_StackRead : MemRead; def TF_TensorArrayRead : MemRead; +def TF_LookupTableRead : MemRead; def TF_VariableWrite : MemWrite; def TF_StackWrite : MemWrite; def TF_TensorArrayWrite : MemWrite; +def TF_SummaryWrite : MemWrite; +def TF_LookupTableWrite : MemWrite; + +def TF_VariableAlloc : MemAlloc; +def TF_StackAlloc : MemAlloc; +def TF_TensorArrayAlloc : MemAlloc; +def TF_SummaryAlloc : MemAlloc; +def TF_LookupTableAlloc : MemAlloc; + +def TF_StackFree : MemFree; +def TF_TensorArrayFree : MemFree; +def TF_SummaryFree : MemFree; //===----------------------------------------------------------------------===// // TensorFlow op definitions @@ -157,20 +172,10 @@ class TF_TensorFlowType : "TensorFlow " # description # " type">, BuildableType<"getType()">; -// Any tensor element type allowed in TensorFlow ops -def TF_ElementType : Type, - "tf.dtype">; - -// Any TensorFlow tensor type -def TF_Tensor : TensorOf<[TF_ElementType]>; - //===----------------------------------------------------------------------===// // Integer types +// TODO(mgester) shouldn't this be SignedIntOfWidths? def TF_I32Or64 : SignlessIntOfWidths<[32, 64]>; def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>; @@ -191,10 +196,11 @@ def TF_Uint64Tensor : TensorOf<[TF_Uint64]>; def TF_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>; // Any signed integer type +// TODO(mgester) shouldn't this be SignedIntOfWidths? 
def TF_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>; // Any integer type -def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt]>; +def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">; // Any integer tensor types def TF_IntTensor : TensorOf<[TF_Int]>; @@ -208,8 +214,8 @@ def TF_Quint8 : TF_TensorFlowType<"Quint8", "quint8">; def TF_Quint16 : TF_TensorFlowType<"Quint16", "quint16">; // Any quantized type -def TF_AnyQuantized : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, - TF_Quint16]>; +def TF_Quantized : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, + TF_Quint16], "quantized">; //===----------------------------------------------------------------------===// // Floating-point types @@ -217,8 +223,10 @@ def TF_F32Or64 : FloatOfWidths<[32, 64]>; def TF_F32OrF64Tensor : TensorOf<[TF_F32Or64]>; +def TF_Float : AnyTypeOf<[F16, F32, F64, BF16], "floating-point">; + // Any floating-point tensor types -def TF_FpTensor : TensorOf<[AnyFloat]>; +def TF_FpTensor : TensorOf<[TF_Float]>; //===----------------------------------------------------------------------===// // Complex types @@ -231,10 +239,9 @@ def TF_Complex64Tensor : TensorOf<[TF_Complex64]>; def TF_Complex128 : Complex>; def TF_Complex128Tensor : TensorOf<[TF_Complex128]>; -def TF_AnyComplex : AnyTypeOf<[TF_Complex64, TF_Complex128], - "64/128-bit complex type">; +def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">; -def TF_ComplexTensor : TensorOf<[TF_AnyComplex]>; +def TF_ComplexTensor : TensorOf<[TF_Complex]>; //===----------------------------------------------------------------------===// // String/variant/resource types @@ -248,27 +255,114 @@ def TF_VariantTensor : TensorOf<[TF_Variant]>; def TF_Resource : TF_TensorFlowType<"Resource", "resource">; def TF_ResourceTensor : TensorOf<[TF_Resource]>; +//===----------------------------------------------------------------------===// +// Reference types + +// Float reference types +def TF_F16Ref : TF_TensorFlowType<"HalfRef", "f16ref">; +def TF_F32Ref : TF_TensorFlowType<"FloatRef", "f32ref">; +def TF_F64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">; +def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">; + +// Any float reference type +def TF_FloatRef : AnyTypeOf<[TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_Bfloat16Ref], + "floating-point reference">; + +// Complex reference types +def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">; +def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">; + +// Any complex reference type +def TF_ComplexRef : AnyTypeOf<[TF_Complex64Ref, TF_Complex128Ref], "complex reference">; + +// Integer reference types +def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">; +def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">; +def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">; +def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">; + +def TF_Uint8Ref : TF_TensorFlowType<"Uint8Ref", "ui8ref">; +def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">; +def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">; +def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">; + +// Any signed integer reference type +def TF_SIntRef : AnyTypeOf<[TF_Int8Ref, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref], + "signed integer reference">; + +// Any unsigned integer reference type +def TF_UIntRef : AnyTypeOf<[TF_Uint8Ref, TF_Uint16Ref, TF_Uint32Ref, + TF_Uint64Ref], "unsigned integer reference">; + +// Any integer reference type +def TF_IntRef : AnyTypeOf<[TF_SIntRef, TF_UIntRef], "integer 
reference">; + +// Quantized reference types +def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">; +def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">; +def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">; +def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">; +def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">; + +// Any quantized reference type +def TF_QuantizedRef : AnyTypeOf<[TF_Qint8Ref, TF_Qint16Ref, TF_Qint32Ref, + TF_Quint8Ref, TF_Quint16Ref], "quantized reference">; + +// Other reference types +def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">; +def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">; +def TF_StrRef : TF_TensorFlowType<"StringRef", "stringref">; +def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">; + +// Reference tensor types +def TF_FpRefTensor : TensorOf<[TF_FloatRef]>; +def TF_I32OrI64RefTensor : TensorOf<[TF_Int32Ref, TF_Int64Ref]>; + //===----------------------------------------------------------------------===// // Multi-category type constraints def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32Or64]>; -def TF_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TF_I32Or64]>; +def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32Or64]>; // Any integer or floating-point tensor types -def TF_IntOrFpTensor : TensorOf<[TF_Int, AnyFloat]>; +def TF_IntOrFpTensor : TensorOf<[TF_Int, TF_Float]>; -def TF_SintOrFpTensor : TensorOf<[TF_SInt, AnyFloat]>; +def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>; -def TF_FpOrComplexTensor : TensorOf<[AnyFloat, TF_AnyComplex]>; +def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>; -def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyQuantized, TF_AnyComplex], - "number">; +def TF_Number : AnyTypeOf<[TF_Int, TF_Float, TF_Quantized, TF_Complex], + "number">; +def TF_NumberRef : AnyTypeOf<[TF_IntRef, TF_FloatRef, TF_QuantizedRef, + TF_ComplexRef], "number reference">; -def TF_NumberTensor : TensorOf<[TF_AnyNumber]>; +def TF_NumberTensor : TensorOf<[TF_Number]>; +def TF_NumberRefTensor : TensorOf<[TF_NumberRef]>; -def TF_NumberOrStr : AnyTypeOf<[AnyFloat, TF_SInt, TF_AnyComplex, TF_Uint8, TF_Str]>; -def TF_NumberOrStrTensor : TensorOf<[TF_NumberOrStr]>; +def TF_NumberNotQuantizedOrStr : + AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>; +def TF_NumberNotQuantizedOrStrRef : + AnyTypeOf<[TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref, TF_StrRef]>; +def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>; + +//===----------------------------------------------------------------------===// +// Tensor and tensor element types + +// Bool type +def TF_Bool : I<1>; + +// Any tensor element type allowed in TensorFlow ops +// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType) +def TF_ElementType : Type, + "tf.dtype">; + +// Any TensorFlow tensor type +def TF_Tensor : TensorOf<[TF_ElementType]>; //===----------------------------------------------------------------------===// // TensorFlow attribute definitions diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h index ec1f748367d..1eb5c89f0fc 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h @@ -15,12 +15,40 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" namespace mlir { namespace TF { + +//===----------------------------------------------------------------------===// +// TensorFlow Contraction Fusion. +//===----------------------------------------------------------------------===// + +struct ContractionFusion { + explicit ContractionFusion( + StringRef output_kernel, ArrayRef additional_arguments = {}, + ArrayRef additional_attributes = {}) + : output_kernel(output_kernel.str()), + additional_arguments(additional_arguments.begin(), + additional_arguments.end()), + additional_attributes(additional_attributes.begin(), + additional_attributes.end()) {} + + // Name of the output kernel implementing the contraction fusion. + std::string output_kernel; + + // Indices of additional arguments that will be forwarded to the fused + // operation (e.g. forward bias vector if fusing BiasAdd operation). + SmallVector additional_arguments; + + // Add additional attributes to the fused node. + SmallVector additional_attributes; +}; + #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc" } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td index 3743bdda043..3c41c04a0d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td @@ -21,7 +21,7 @@ limitations under the License. include "mlir/IR/OpBase.td" //===----------------------------------------------------------------------===// -// TensorFlow interfaces +// TensorFlow Layout Optimization Interfaces. //===----------------------------------------------------------------------===// def TF_LayoutSensitiveInterface : OpInterface<"LayoutSensitiveInterface"> { @@ -104,4 +104,25 @@ def TF_FoldOperandsTransposeInterface : OpInterface<"FoldOperandsTransposeInterf }]; } +//===----------------------------------------------------------------------===// +// TensorFlow Contraction Fusion Interfaces. +//===----------------------------------------------------------------------===// + +def TF_ContractionFusableInterface : OpInterface<"ContractionFusableInterface"> { + let description = [{ + A contraction fusable operation is one that can be fused into the output of + a tensor contraction (MatMul, Conv2D, etc...) operation. + + For example all element wise operations are trivially contraction fusable. + }]; + + let methods = [ + InterfaceMethod< + [{Returns contraction fusion if the operation satisfies all the fusion + requirements. Otherwise returns empty optional.}], + "Optional", "GetContractionFusion", (ins) + >, + ]; +} + #endif // TF_OP_INTERFACES diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index abff4c21cf1..634004038d0 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -55,6 +55,8 @@ limitations under the License. 
#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/DecodeAttributesInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project #include "mlir/Parser.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -74,23 +76,6 @@ namespace TF { //===----------------------------------------------------------------------===// namespace { -// Returns true if the op can be duplicated. -bool CanDuplicate(Operation *op) { - // If the op is marked with the cannot duplicate trait, it cannot be - // duplicated. - if (op->hasTrait()) return false; - - // If the op has no memory side effects, it can be duplicated. - if (MemoryEffectOpInterface::hasNoEffect(op)) return true; - - // If the op is marked stateless using the `is_stateless` attribute, that - // attribute determines if the op can be duplicated. - if (auto is_stateless = op->getAttrOfType("is_stateless")) - return is_stateless.getValue(); - - // Otherwise, assume ops can be duplicated by default. - return true; -} // Returns true of the given function has a single uses (within the scope // of the module containing it and all parent modules). @@ -129,6 +114,22 @@ bool HasSingleUse(FuncOp func) { return true; } +struct TFConstantFoldInterface : public DialectFoldInterface { + TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {} + LogicalResult fold(Operation *op, ArrayRef operands, + SmallVectorImpl &results) const final { + return TensorFlowDialect::constantFold(op, operands, results); + } +}; + +struct TFDecodeAttributesInterface : public DialectDecodeAttributesInterface { + TFDecodeAttributesInterface(Dialect *dialect) + : DialectDecodeAttributesInterface(dialect) {} + LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) const { + return TensorFlowDialect::decode(input, output); + } +}; + struct TFInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; @@ -156,7 +157,7 @@ struct TFInlinerInterface : public DialectInlinerInterface { // post inlining, the function will be dead and eliminated from the IR. // So there won't be any code duplication. FuncOp func = op->getParentOfType(); - return !func || CanDuplicate(op) || HasSingleUse(func); + return !func || TensorFlowDialect::CanDuplicate(op) || HasSingleUse(func); } //===--------------------------------------------------------------------===// @@ -183,22 +184,66 @@ struct TFInlinerInterface : public DialectInlinerInterface { // TF Dialect //===----------------------------------------------------------------------===// +// Returns true if the op can be duplicated. +bool TensorFlowDialect::CanDuplicate(Operation *op) { + // If the op is marked with the cannot duplicate trait, it cannot be + // duplicated. + if (op->hasTrait()) return false; + + // If the op has no memory side effects, it can be duplicated. + if (MemoryEffectOpInterface::hasNoEffect(op)) return true; + + // If the op is marked stateless using the `is_stateless` attribute, that + // attribute determines if the op can be duplicated. + if (auto is_stateless = op->getAttrOfType("is_stateless")) + return is_stateless.getValue(); + + // Otherwise, assume ops can be duplicated by default if its registered, else + // it cannot be for unknown ops. 
+ return op->isRegistered(); +} + +// Returns true if the op can have side effects. +bool TensorFlowDialect::CanHaveSideEffects(Operation *op) { + // If the op has no memory side effects, it has no side effects + if (MemoryEffectOpInterface::hasNoEffect(op)) return false; + + // If the op is marked stateless using the `is_stateless` attribute, then + // it has no side effects. + if (auto is_stateless = op->getAttrOfType("is_stateless")) + return !is_stateless.getValue(); + + // Terminators defined in the TF dialect do not have side effects. + if (op->isKnownTerminator()) return false; + + // Otherwise assume that the op can have side effects. + return true; +} + std::vector *TensorFlowDialect::additional_operation_hooks_ = new std::vector(); +TensorFlowDialect::ConstantFoldHook TensorFlowDialect::constant_fold_hook_; +TensorFlowDialect::DecodeConstantHook TensorFlowDialect::decode_constant_hook_; + TensorFlowDialect::TensorFlowDialect(MLIRContext *context) - : Dialect(/*name=*/"tf", context) { + : Dialect(/*name=*/"tf", context, TypeID::get()) { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.cc.inc" >(); + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc.inc" + >(); addTypes< #define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type, #define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" >(); - addInterfaces(); + addInterfaces(); addAttributes(); // Support unknown operations because not all TensorFlow operations are @@ -317,16 +362,12 @@ Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser, void TensorFlowDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { - switch (attr.getKind()) { - case AttrKind::SHAPE: - PrintShapeAttr(attr.cast(), os); - break; - case AttrKind::FUNC: - PrintFuncAttr(attr.cast(), os); - break; - default: - llvm_unreachable("unexpected tensorflow attribute kind"); - } + if (auto shape_attr = attr.dyn_cast()) + PrintShapeAttr(shape_attr, os); + else if (auto func_attr = attr.dyn_cast()) + PrintFuncAttr(func_attr, os); + else + llvm_unreachable("unexpected tensorflow attribute type"); } // Parses a type registered to this dialect. @@ -335,51 +376,37 @@ Type TensorFlowDialect::parseType(DialectAsmParser &parser) const { if (parser.parseKeyword(&data)) return Type(); Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); - auto typeKind = llvm::StringSwitch(data) + #define HANDLE_TF_TYPE(tftype, enumerant, name) \ - .Case(name, TensorFlowTypes::enumerant) + if (data == name) return tftype##Type::get(getContext()); // Custom TensorFlow types are handled separately at the end as they do partial // match. 
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name) // NOLINTNEXTLINE #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" - .StartsWith("resource", TensorFlowTypes::RESOURCE) - .StartsWith("variant", TensorFlowTypes::VARIANT) - .Default(0); - switch (typeKind) { - default: - return (emitError(loc, "unknown TensorFlow type: " + data), nullptr); -#define HANDLE_TF_TYPE(tftype, enumerant, name) \ - case TensorFlowTypes::enumerant: \ - return tftype##Type::get(getContext()); -#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name) -// NOLINTNEXTLINE -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" - case TensorFlowTypes::RESOURCE: - return ParseResourceType(parser, loc); - case TensorFlowTypes::VARIANT: - return ParseVariantType(parser, loc); - } + if (data.startswith("resource")) return ParseResourceType(parser, loc); + if (data.startswith("variant")) return ParseVariantType(parser, loc); + return (emitError(loc, "unknown TensorFlow type: " + data), nullptr); } // Prints a type registered to this dialect. void TensorFlowDialect::printType(Type ty, DialectAsmPrinter &os) const { assert(ty.isa()); - switch (ty.getKind()) { - default: - llvm_unreachable("unexpected tensorflow type kind"); -#define HANDLE_TF_TYPE(tftype, enumerant, name) \ - case TensorFlowTypes::enumerant: \ - os << name; \ - break; +#define HANDLE_TF_TYPE(tftype, enumerant, name) \ + if (auto derived_ty = ty.dyn_cast()) { \ + os << name; \ + return; \ + } #define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name) \ - case TensorFlowTypes::enumerant: \ - Print##tftype##Type(ty.cast(), os); \ - break; + if (auto derived_ty = ty.dyn_cast()) { \ + Print##tftype##Type(derived_ty, os); \ + return; \ + } // NOLINTNEXTLINE #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" - } + + llvm_unreachable("unexpected tensorflow type kind"); } namespace { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index 039ed1bc3a8..2755a62a3c9 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h" namespace mlir { namespace TF { @@ -63,6 +64,12 @@ class TensorFlowDialect : public Dialect { // Returns the string description of stateful attribute. static StringRef GetStatefulAttrName() { return "tf.signature.is_stateful"; } + // Returns true if the op can be duplicated during transformations. + static bool CanDuplicate(Operation *op); + + // Returns true if the op can have side effects. 
+ static bool CanHaveSideEffects(Operation *op); + Attribute parseAttribute(DialectAsmParser &parser, Type type) const override; void printAttribute(Attribute attr, DialectAsmPrinter &os) const override; @@ -110,10 +117,35 @@ class TensorFlowDialect : public Dialect { 0, (addOperation(AbstractOperation::get(*this)), 0)...}; } + using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef, + SmallVectorImpl &); + static void RegisterConstantFoldHook(ConstantFoldHook fn) { + constant_fold_hook_ = std::move(fn); + } + + static LogicalResult constantFold(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + if (constant_fold_hook_) return constant_fold_hook_(op, operands, results); + return failure(); + } + + using DecodeConstantHook = LogicalResult (*)(OpaqueElementsAttr input, + ElementsAttr &output); + static void RegisterDecodeConstantHook(DecodeConstantHook fn) { + decode_constant_hook_ = std::move(fn); + } + static LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) { + if (decode_constant_hook_) return decode_constant_hook_(input, output); + return failure(); + } + private: // Hook functions which may add additional operations to the dialect. // These are invoked at construction time. static std::vector *additional_operation_hooks_; + + static ConstantFoldHook constant_fold_hook_; + static DecodeConstantHook decode_constant_hook_; }; } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 1e99675d938..5fe19f7b0cb 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -68,6 +68,100 @@ class TF_TensorListInitOp : TF_Op { }]; } +def TF_CaseOp : TF_Op<"Case", []> { + let summary = [{ +An n-way switch statement which calls a single branch function. + }]; + + let description = [{ +An n-way switch statement, implementing the following: + ``` + switch (branch_index) { + case 0: + output = branches[0](input); + break; + case 1: + output = branches[1](input); + break; + ... + case [[nbranches-1]]: + default: + output = branches[nbranches-1](input); + break; + } + ``` + }]; + + let arguments = (ins + I32Tensor:$branch_index, + Variadic:$input, + + Confined]>:$branches, + + // Used to map StatelessCase and Case op defined in TensorFlow to a common + // op. + BoolAttr:$is_stateless + ); + + let results = (outs + Variadic:$output + ); + + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; + TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>; + + let hasCanonicalizer = 1; + + let verifier = [{ + return Verify(*this); + }]; +} + +def TF_CaseRegionOp : TF_Op<"CaseRegion", + [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> { + let summary = [{ +An n-way switch statement which calls a single branch function. + }]; + + let description = [{ +An n-way switch statement, implementing the following: + ``` + switch (branch_index) { + case 0: + output = branches[0](input); + break; + case 1: + output = branches[1](input); + break; + ... + case [[nbranches-1]]: + default: + output = branches[nbranches-1](input); + break; + } + ``` + }]; + + let arguments = (ins + I32Tensor:$branch_index, + + // Used to map StatelessCase and Case op defined in TensorFlow to a common + // op. 
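The header half of this hunk adds `RegisterConstantFoldHook`/`RegisterDecodeConstantHook` so another library (presumably the constant-folding utilities) can install behavior after the fact; when nothing has been registered the dialect simply returns `failure()`. A simplified standalone sketch of that registration pattern, with toy types in place of `Operation`/`Attribute` (every name below is made up for illustration):

```
#include <iostream>
#include <vector>

// The dialect holds a static function pointer, a separate library registers
// an implementation, and callers get a failure result when no hook exists.
using FoldHook = bool (*)(const std::vector<int> &operands,
                          std::vector<int> &results);

struct ToyDialect {
  static FoldHook fold_hook;
  static void RegisterFoldHook(FoldHook fn) { fold_hook = fn; }
  static bool ConstantFold(const std::vector<int> &operands,
                           std::vector<int> &results) {
    if (fold_hook) return fold_hook(operands, results);
    return false;  // Mirrors `return failure();` when nothing is registered.
  }
};
FoldHook ToyDialect::fold_hook = nullptr;

// A toy hook: "fold" by summing the operands.
bool SumFold(const std::vector<int> &operands, std::vector<int> &results) {
  int sum = 0;
  for (int v : operands) sum += v;
  results.push_back(sum);
  return true;
}

int main() {
  std::vector<int> results;
  std::cout << ToyDialect::ConstantFold({1, 2}, results) << "\n";  // 0: no hook
  ToyDialect::RegisterFoldHook(SumFold);
  std::cout << ToyDialect::ConstantFold({1, 2}, results) << "\n";  // 1: folded
  std::cout << results[0] << "\n";                                 // 3
}
```

The indirection keeps the dialect library itself free of a dependency on the folding implementation while still letting `TFConstantFoldInterface` route through it.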
+ BoolAttr:$is_stateless + ); + + let results = (outs + Variadic:$output + ); + + let regions = (region VariadicRegion>:$branches); + + let verifier = [{ + return Verify(*this); + }]; +} + // In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with // its type encoding the tensor's shape and data type. def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect, @@ -123,30 +217,6 @@ source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs: TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } - -def TF_DataFormatVecPermuteOp : TF_Op<"DataFormatVecPermute", [NoSideEffect, SameOperandsAndResultType]> { - let summary = "Permute input tensor from `src_format` to `dst_format`"; - - let description = [{ -Input tensor must be a vector of size 4, or a 4x2 tensor. - }]; - - let arguments = (ins - TF_I32OrI64Tensor:$x, - - DefaultValuedAttr:$src_format, - DefaultValuedAttr:$dst_format - ); - - let results = (outs - TF_I32OrI64Tensor:$y - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - - let verifier = [{ return Verify(*this); }]; -} - def TF_EmptyTensorListOp : TF_TensorListInitOp<"EmptyTensorList"> { let summary = "Creates and returns an empty tensor list."; @@ -235,19 +305,19 @@ else_branch: A function that takes 'inputs' and returns a list of let extraClassDeclaration = [{ // Get the then branch function. - FuncOp then_func() { + FuncOp then_function() { return SymbolTable::lookupNearestSymbolFrom(*this, then_branch()); } // Get the else branch function. - FuncOp else_func() { + FuncOp else_function() { return SymbolTable::lookupNearestSymbolFrom(*this, else_branch()); } }]; } def TF_YieldOp : TF_Op<"Yield", - [Terminator, ParentOneOf<["IfRegionOp", "WhileRegionOp"]>]> { + [Terminator, ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>]> { let summary = "Yield operation"; let description = [{ @@ -283,7 +353,7 @@ else_branch: A region that computes the outputs of the op if cond = false. }]; let arguments = (ins - TF_Tensor:$cond, + 0DTensorOf<[I1]>:$cond, // Used to map StatelessIf and If op defined in TensorFlow to a common op. BoolAttr:$is_stateless @@ -293,47 +363,13 @@ else_branch: A region that computes the outputs of the op if cond = false. Variadic:$output ); - TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>; - TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; - TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; - let regions = (region SizedRegion<1>:$then_branch, SizedRegion<1>:$else_branch); let verifier = [{ return Verify(*this); }]; -} -def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { - let summary = "Computes the mean of elements across dimensions of a tensor."; - - let description = [{ -Reduces `input` along the dimensions given in `axis`. Unless -`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -`axis`. If `keep_dims` is true, the reduced dimensions are -retained with length 1. 
- }]; - - let arguments = (ins - TF_NumberTensor:$input, - TF_I32OrI64Tensor:$reduction_indices, - - DefaultValuedAttr:$keep_dims - ); - - let results = (outs - TF_NumberTensor:$output - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; - - let extraClassDeclaration = [{ - // TF_FoldOperandsTransposeInterface: - SmallVector GetLayoutDependentArgs() { return {0}; } - SmallVector GetLayoutDependentResults() { return {}; } - LogicalResult FoldOperandsPermutation(ArrayRef permutation); - }]; + let hasCanonicalizer = 1; } def TF_LegacyCallOp : TF_Op<"LegacyCall", @@ -534,36 +570,6 @@ def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [NoSideEffect] DerivedAttr shape = TF_DerivedResultShapeAttr; } -def TF_SparseMatMulOp : TF_Op<"SparseMatMul", [NoSideEffect]> { - let summary = [{ -SparseMatMul is MatMul with hints on the sparseness of the matrices. - }]; - - let description = [{ -Similar to MatMul, with a_is_sparse and b_is_sparse indicating whether a and b -are sparse matrices. - }]; - - let arguments = (ins - TensorOf<[BF16, F32]>:$a, - TensorOf<[BF16, F32]>:$b, - - DefaultValuedAttr:$a_is_sparse, - DefaultValuedAttr:$b_is_sparse, - - DefaultValuedAttr:$transpose_a, - DefaultValuedAttr:$transpose_b - ); - - let results = (outs - TensorOf<[F32]>:$product - ); - - TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>; - TF_DerivedOperandTypeAttr Tb = TF_DerivedOperandTypeAttr<1>; -} - - def TF_StatefulPartitionedCallOp : TF_Op<"StatefulPartitionedCall", [CallOpInterface]> { let summary = @@ -655,12 +661,12 @@ body: A function that takes a list of tensors and returns another let extraClassDeclaration = [{ // Get the condition function. - FuncOp cond_func() { + FuncOp cond_function() { return SymbolTable::lookupNearestSymbolFrom(*this, cond()); } // Get the body function. - FuncOp body_func() { + FuncOp body_function() { return SymbolTable::lookupNearestSymbolFrom(*this, body()); } }]; @@ -710,8 +716,6 @@ def TL_WhileRegionOp : TF_Op<"WhileRegion", ); let results = (outs Variadic:$output); - TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; - let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); let verifier = [{ return Verify(*this); }]; @@ -787,7 +791,7 @@ Example: ); let results = (outs - TF_ResourceTensor:$resource + Res:$resource ); TF_DerivedOperandOrResultHandleTypeAttr dtype = @@ -796,45 +800,6 @@ Example: TF_DerivedOperandOrResultHandleShapeAttr<"resource">; } -// Not generated because it begins with an underscore, which isn't allowed by -// the C++ standard. -def TF_FusedBatchNormExOp : TF_Op<"_FusedBatchNormEx", [NoSideEffect]> { - let summary = "Internal FusedBatchNorm operation: reserved for internal use"; - - let description = [{ - Do not invoke this operator directly in Python. A fusion optimization is - expected to create these operators. 
- }]; - - let arguments = (ins - TensorOf<[F16, F32]>:$x, - F32Tensor:$scale, - F32Tensor:$offset, - F32Tensor:$mean, - F32Tensor:$variance, - Variadic>:$side_input, - - DefaultValuedAttr:$epsilon, - DefaultValuedAttr:$exponential_avg_factor, - DefaultValuedAttr:$activation_mode, - DefaultValuedAttr:$data_format, - DefaultValuedAttr:$is_training - ); - - let results = (outs - TensorOf<[F16, F32]>:$y, - F32Tensor:$batch_mean, - F32Tensor:$batch_variance, - F32Tensor:$reserve_space_1, - F32Tensor:$reserve_space_2, - F32Tensor:$reserve_space_3 - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; - TF_DerivedOperandSizeAttr num_side_inputs = TF_DerivedOperandSizeAttr<5>; -} - // Multiple variadic operands with different sizes are not supported by the // dialect generator, so we manually added the op. def TF_SendTPUEmbeddingGradientsOp : TF_Op<"SendTPUEmbeddingGradients", [AttrSizedOperandSegments]> { @@ -1105,6 +1070,43 @@ def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> { TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>; } +def TF_ToBoolOp : TF_Op<"ToBool", [NoSideEffect]> { + let summary = "Converts a tensor to a scalar predicate."; + + let description = [{ +Converts a tensor to a scalar predicate with the following rules: + +- For 0D tensors, truthiness is determined by comparing against a "zero" + value. For numerical types it is the obvious zero. For strings it is the + empty string. + +- For >0D tensors, truthiness is determined by looking at the number of + elements. If has zero elements, then the result is false. Otherwise the + result is true. + +This matches the behavior of If and While for determining if a tensor counts +as true/false for a branch condition. + }]; + + let arguments = (ins + TF_Tensor:$input + ); + + let results = (outs + 0DTensorOf<[I1]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, Value value", [{ + build(builder, result, RankedTensorType::get({}, builder.getI1Type()), + value); + }]>]; + + let hasCanonicalizer = 1; +} + def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes the Bessel i0e function of `x` element-wise."; @@ -1147,36 +1149,6 @@ This function is faster and numerically stabler than `bessel_i1(x)`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_StringToHashBucketFastOp : TF_Op<"StringToHashBucketFast", [NoSideEffect]> { - let summary = [{ -Converts each string in the input Tensor to its hash mod by a number of buckets. - }]; - - let description = [{ -The hash function is deterministic on the content of the string within the -process and will never change. However, it is not suitable for cryptography. -This function may be used when CPU time is scarce and inputs are trusted or -unimportant. There is a risk of adversaries constructing inputs that all hash -to the same bucket. To prevent this problem, use a strong hash function with -`tf.string_to_hash_bucket_strong`. 
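The bucketing described above is a deterministic hash reduced modulo `num_buckets`. A standalone sketch of that mapping; `std::hash` is only a stand-in for TensorFlow's stable fingerprint, so the concrete bucket ids it produces will differ from the `tf.strings.to_hash_bucket_fast` example that follows:

```
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// bucket = hash(input) % num_buckets, applied element-wise to a string tensor.
std::vector<int64_t> ToHashBucketFast(const std::vector<std::string> &input,
                                      int64_t num_buckets) {
  std::vector<int64_t> output;
  output.reserve(input.size());
  for (const std::string &s : input)
    output.push_back(static_cast<int64_t>(
        std::hash<std::string>{}(s) % static_cast<uint64_t>(num_buckets)));
  return output;
}

int main() {
  for (int64_t b : ToHashBucketFast({"Hello", "TensorFlow", "2.x"}, 3))
    std::cout << b << " ";  // Bucket ids in [0, 3); values depend on the hash.
  std::cout << "\n";
}
```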
- -Examples: - ->>> tf.strings.to_hash_bucket_fast(["Hello", "TensorFlow", "2.x"], 3).numpy() -array([0, 2, 2]) - }]; - - let arguments = (ins - TF_StrTensor:$input, - - Confined]>:$num_buckets - ); - - let results = (outs - I64Tensor:$output - ); -} - def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> { let summary = "Calls a function placed on a specified TPU device."; @@ -1211,63 +1183,6 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> { let verifier = [{ return VerifyPartitionedCall(*this); }]; } -class TF_FusedBatchNormOpBase : TF_Op { - let summary = "Batch normalization."; - - let description = [{ -Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -The size of 1D Tensors matches the dimension C of the 4D Tensors. - }]; - - let arguments = (ins - TensorOf<[BF16, F16, F32]>:$x, - F32Tensor:$scale, - F32Tensor:$offset, - F32Tensor:$mean, - F32Tensor:$variance, - - DefaultValuedAttr:$epsilon, - DefaultValuedAttr:$exponential_avg_factor, - DefaultValuedAttr:$data_format, - DefaultValuedAttr:$is_training - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; - - let extraClassDeclaration = [{ - // TF_FoldOperandsTransposeInterface: - SmallVector GetLayoutDependentArgs() { return {0}; } - SmallVector GetLayoutDependentResults() { return {0}; } - LogicalResult FoldOperandsPermutation(ArrayRef permutation); - - // TF_LayoutSensitiveInterface: - StringRef GetOptimalLayout(const RuntimeDevices& devices); - LogicalResult UpdateDataFormat(StringRef data_format); - }]; -} - -def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> { - let results = (outs - TensorOf<[BF16, F16, F32]>:$y, - F32Tensor:$batch_mean, - F32Tensor:$batch_variance, - F32Tensor:$reserve_space_1, - F32Tensor:$reserve_space_2 - ); -} - -def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> { - let results = (outs - TensorOf<[BF16, F16, F32]>:$y, - F32Tensor:$batch_mean, - F32Tensor:$batch_variance, - F32Tensor:$reserve_space_1, - F32Tensor:$reserve_space_2, - F32Tensor:$reserve_space_3 - ); -} - def TF_BatchFunctionOp : TF_Op<"BatchFunction", [AttrSizedOperandSegments]> { let summary = [{ Batches all the inputs tensors to the computation done by the function. @@ -1295,6 +1210,7 @@ So, for example, in the following code batch_timeout_micros=100000, # 100ms allowed_batch_sizes=[3, 10], batching_queue="") + ``` If more than one session.run call is simultaneously trying to compute `b` the values of `a` will be gathered, non-deterministically concatenated @@ -1338,4 +1254,625 @@ must be a Tensor or a list/tuple of Tensors. TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; } +def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns x + y element-wise."; + + let description = [{ +*NOTE*: `Add` supports broadcasting. `AddN` does not. 
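Broadcasting here is the NumPy rule linked below: shapes are aligned from the trailing dimension, and each pair of dimensions must be equal or one of them must be 1. A small standalone sketch of that shape computation for static dimensions (the BatchMatMulV2 verifier later in this diff applies the same rule to the batch dimensions via `OpTrait::util::getBroadcastedShape`; dynamic dimensions need extra care and are not modeled here):

```
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Computes the broadcasted shape of two shapes, aligning from the right.
// Returns false if some dimension pair is incompatible (neither equal nor 1).
bool BroadcastShapes(const std::vector<int64_t> &a,
                     const std::vector<int64_t> &b,
                     std::vector<int64_t> &out) {
  size_t rank = std::max(a.size(), b.size());
  out.assign(rank, 1);
  for (size_t i = 0; i < rank; ++i) {
    int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) return false;
    out[rank - 1 - i] = std::max(da, db);
  }
  return true;
}

int main() {
  std::vector<int64_t> out;
  // [2, 3, 1] with [3, 4] broadcasts to [2, 3, 4].
  if (BroadcastShapes({2, 3, 1}, {3, 4}, out))
    for (int64_t d : out) std::cout << d << " ";
  std::cout << "\n";
}
```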
More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref, TF_Uint32Ref]>:$x, + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref, TF_Uint32Ref]>:$y + ); + + let results = (outs + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref, TF_Uint32Ref]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasCanonicalizer = 1; + + let hasFolder = 1; +} + +def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns 0 if the denominator is zero."; + + let description = [{ +*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$x, + TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$y + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise."; + + let description = [{ +*NOTE*: `Maximum` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$x, + TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$y + ); + + let results = (outs + TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns x / y element-wise for real types."; + + let description = [{ +If `x` and `y` are reals, this will return the floating-point division. + +*NOTE*: `Div` supports broadcasting. 
More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$x, + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$y + ); + + let results = (outs + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasCanonicalizer = 1; + + let hasFolder = 1; +} + +def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns x + y element-wise."; + + let description = [{ +*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +Given two input tensors, the `tf.add` operation computes the sum for every element in the tensor. + +Both input and output have a range `(-inf, inf)`. + }]; + + let arguments = (ins + TensorOf<[TF_NumberNotQuantizedOrStr, TF_NumberNotQuantizedOrStrRef]>:$x, + TensorOf<[TF_NumberNotQuantizedOrStr, TF_NumberNotQuantizedOrStrRef]>:$y + ); + + let results = (outs + TensorOf<[TF_NumberNotQuantizedOrStr, TF_NumberNotQuantizedOrStrRef]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasCanonicalizer = 1; +} + +def TF_StatefulStandardNormalV2Op : TF_Op<"StatefulStandardNormalV2", []> { + let summary = "Outputs random values from a normal distribution."; + + let description = [{ +The generated values will have mean 0 and standard deviation 1. + }]; + + let arguments = (ins + Arg:$resource, + I64Tensor:$algorithm, + TF_I32OrI64Tensor:$shape + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_StatefulTruncatedNormalOp : TF_Op<"StatefulTruncatedNormal", []> { + let summary = "Outputs random values from a truncated normal distribution."; + + let description = [{ +The generated values follow a normal distribution with mean 0 and standard +deviation 1, except that values whose magnitude is more than 2 standard +deviations from the mean are dropped and re-picked. + }]; + + let arguments = (ins + Arg:$resource, + I64Tensor:$algorithm, + TF_I32OrI64Tensor:$shape + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_StatefulUniformOp : TF_Op<"StatefulUniform", []> { + let summary = "Outputs random values from a uniform distribution."; + + let description = [{ +The generated values follow a uniform distribution in the range `[0, 1)`. The +lower bound 0 is included in the range, while the upper bound 1 is excluded. 
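The half-open range documented above means 0 can be produced but 1 never is. A tiny standalone illustration using the standard library generator; the actual kernel draws from the stateful `resource`/`algorithm` pair (a counter-based generator), which this does not model:

```
#include <iostream>
#include <random>

int main() {
  // Draw a few samples uniformly from [0, 1): 0.0 can occur, 1.0 cannot.
  std::mt19937_64 rng(/*seed=*/42);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  for (int i = 0; i < 4; ++i) std::cout << uniform(rng) << " ";
  std::cout << "\n";
}
```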
+ }]; + + let arguments = (ins + Arg:$resource, + I64Tensor:$algorithm, + TF_I32OrI64Tensor:$shape + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_StatefulUniformFullIntOp : TF_Op<"StatefulUniformFullInt", []> { + let summary = "Outputs random integers from a uniform distribution."; + + let description = [{ +The generated values are uniform integers covering the whole range of `dtype`. + }]; + + let arguments = (ins + Arg:$resource, + I64Tensor:$algorithm, + TF_I32OrI64Tensor:$shape + ); + + let results = (outs + TensorOf<[I32, I64, TF_Uint32, TF_Uint64]>:$output + ); + + TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +// TODO(lyandy): Investigate supported dtypes (`minval`, `maxval`, `output`) for +// `tf.StatefulUniformInt`. tf2xla kernels support i32, i64, ui32, and ui64 +// while TensorFlow CPU/GPU kernels only support i32 and i64. +def TF_StatefulUniformIntOp : TF_Op<"StatefulUniformInt", []> { + let summary = "Outputs random integers from a uniform distribution."; + + let description = [{ +The generated values are uniform integers in the range `[minval, maxval)`. +The lower bound `minval` is included in the range, while the upper bound +`maxval` is excluded. + +The random integers are slightly biased unless `maxval - minval` is an exact +power of two. The bias is small for values of `maxval - minval` significantly +smaller than the range of the output (either `2^32` or `2^64`). + }]; + + let arguments = (ins + Arg:$resource, + I64Tensor:$algorithm, + TF_I32OrI64Tensor:$shape, + TensorOf<[I32, I64, TF_Uint32, TF_Uint64]>:$minval, + TensorOf<[I32, I64, TF_Uint32, TF_Uint64]>:$maxval + ); + + let results = (outs + TensorOf<[I32, I64, TF_Uint32, TF_Uint64]>:$output + ); + + TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<3>; +} + +def TF_CloseSummaryWriterOp : TF_Op<"CloseSummaryWriter", []> { + let summary = "Flushes and closes the summary writer."; + + let description = [{ +Also removes it from the resource manager. To reopen, use another +CreateSummaryFileWriter op. + +writer: A handle to the summary writer resource. + }]; + + let arguments = (ins + Arg:$writer + ); + + let results = (outs); +} + +// TODO(b/168035831): Model db_uri read/write. +def TF_CreateSummaryDbWriterOp : TF_Op<"CreateSummaryDbWriter", []> { + let summary = "Creates summary database writer accessible by given resource handle."; + + let description = [{ +This can be used to write tensors from the execution graph directly +to a database. Only SQLite is supported right now. This function +will create the schema if it doesn't exist. Entries in the Users, +Experiments, and Runs tables will be created automatically if they +don't already exist. + +writer: Handle to SummaryWriter resource to overwrite. +db_uri: For example "file:/tmp/foo.sqlite". +experiment_name: Can't contain ASCII control characters or <>. Case + sensitive. If empty, then the Run will not be associated with any + Experiment. +run_name: Can't contain ASCII control characters or <>. Case sensitive. + If empty, then each Tag will not be associated with any Run. +user_name: Must be valid as both a DNS label and Linux username. If + empty, then the Experiment will not be associated with any User. 
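The naming constraints above (no ASCII control characters, no '<' or '>') are simple to check up front; a small illustrative validator, not part of the op's implementation:

```
#include <iostream>
#include <string>

// Returns true if the name satisfies the documented constraint:
// no ASCII control characters and no '<' or '>'.
bool IsValidSummaryName(const std::string &name) {
  for (unsigned char c : name) {
    if (c < 0x20 || c == 0x7f) return false;  // ASCII control characters.
    if (c == '<' || c == '>') return false;
  }
  return true;
}

int main() {
  std::cout << IsValidSummaryName("experiment_1") << "\n";       // 1
  std::cout << IsValidSummaryName("bad<name>") << "\n";          // 0
  std::cout << IsValidSummaryName(std::string("a\tb")) << "\n";  // 0 (tab)
}
```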
+ }]; + + let arguments = (ins + Arg:$writer, + TF_StrTensor:$db_uri, + TF_StrTensor:$experiment_name, + TF_StrTensor:$run_name, + TF_StrTensor:$user_name + ); + + let results = (outs); +} + +// TODO(b/168035831): Model logdir read/write. +def TF_CreateSummaryFileWriterOp : TF_Op<"CreateSummaryFileWriter", []> { + let summary = "Creates a summary file writer accessible by the given resource handle."; + + let description = [{ +writer: A handle to the summary writer resource +logdir: Directory where the event file will be written. +max_queue: Size of the queue of pending events and summaries. +flush_millis: How often, in milliseconds, to flush the pending events and + summaries to disk. +filename_suffix: Every event file's name is suffixed with this suffix. + }]; + + let arguments = (ins + Arg:$writer, + TF_StrTensor:$logdir, + I32Tensor:$max_queue, + I32Tensor:$flush_millis, + TF_StrTensor:$filename_suffix + ); + + let results = (outs); +} + +def TF_FlushSummaryWriterOp : TF_Op<"FlushSummaryWriter", []> { + let summary = "Flushes the writer's unwritten events."; + + let description = [{ +writer: A handle to the summary writer resource. + }]; + + let arguments = (ins + Arg:$writer + ); + + let results = (outs); +} + +def TF_ImportEventOp : TF_Op<"ImportEvent", []> { + let summary = "Outputs a `tf.Event` protocol buffer."; + + let description = [{ +When CreateSummaryDbWriter is being used, this op can be useful for +importing data from event logs. + +writer: A handle to a summary writer. +event: A string containing a binary-encoded tf.Event proto. + }]; + + let arguments = (ins + Arg:$writer, + TF_StrTensor:$event + ); + + let results = (outs); +} + +def TF_SummaryWriterOp : TF_Op<"SummaryWriter", []> { + let summary = "Returns a handle to be used to access a summary writer."; + + let description = [{ +The summary writer is an in-graph resource which can be used by ops to write +summaries to event files. + +writer: the summary writer resource. Scalar handle. + }]; + + let arguments = (ins + StrAttr:$shared_name, + StrAttr:$container + ); + + let results = (outs + Res:$writer + ); +} + +def TF_WriteAudioSummaryOp : TF_Op<"WriteAudioSummary", []> { + let summary = "Writes a `Summary` protocol buffer with audio."; + + let description = [{ +The summary has up to `max_outputs` summary values containing audio. The +audio is built from `tensor` which must be 3-D with shape `[batch_size, +frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are +assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. + +The `tag` argument is a scalar `Tensor` of type `string`. It is used to +build the `tag` of the summary values: + +* If `max_outputs` is 1, the summary value tag is '*tag*/audio'. +* If `max_outputs` is greater than 1, the summary value tags are + generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. + +writer: A handle to a summary writer. +step: The step to write the summary for. +tag: Scalar. Used to build the `tag` attribute of the summary values. +tensor: 2-D of shape `[batch_size, frames]`. +sample_rate: The sample rate of the signal in hertz. +max_outputs: Max number of batch elements to generate audio for. 
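The tag layout described above is mechanical: a single output uses '*tag*/audio', multiple outputs get a '/N' suffix. An illustrative sketch:

```
#include <iostream>
#include <string>
#include <vector>

// Builds summary value tags following the documented rule:
// one output -> "<tag>/audio", several -> "<tag>/audio/0", "<tag>/audio/1", ...
std::vector<std::string> AudioSummaryTags(const std::string &tag,
                                          int num_outputs) {
  std::vector<std::string> tags;
  for (int i = 0; i < num_outputs; ++i) {
    if (num_outputs == 1)
      tags.push_back(tag + "/audio");
    else
      tags.push_back(tag + "/audio/" + std::to_string(i));
  }
  return tags;
}

int main() {
  for (const auto &t : AudioSummaryTags("waveform", 3)) std::cout << t << "\n";
}
```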
+ }]; + + let arguments = (ins + Arg:$writer, + I64Tensor:$step, + TF_StrTensor:$tag, + F32Tensor:$tensor, + F32Tensor:$sample_rate, + + Confined, [IntMinValue<1>]>:$max_outputs + ); + + let results = (outs); +} + +def TF_WriteGraphSummaryOp : TF_Op<"WriteGraphSummary", []> { + let summary = "Writes a `GraphDef` protocol buffer to a `SummaryWriter`."; + + let description = [{ +writer: Handle of `SummaryWriter`. +step: The step to write the summary for. +tensor: A scalar string of the serialized tf.GraphDef proto. + }]; + + let arguments = (ins + Arg:$writer, + I64Tensor:$step, + TF_StrTensor:$tensor + ); + + let results = (outs); +} + +def TF_WriteHistogramSummaryOp : TF_Op<"WriteHistogramSummary", []> { + let summary = "Writes a histogram summary."; + + let description = [{ +The generated +[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) +has one summary value containing a histogram for `values`. + +This op reports an `InvalidArgument` error if any value is not finite. + +writer: A handle to a summary writer. +step: The step to write the summary for. +tag: Scalar. Tag to use for the `Summary.Value`. +values: Any shape. Values to use to build the histogram. + }]; + + let arguments = (ins + Arg:$writer, + I64Tensor:$step, + TF_StrTensor:$tag, + TF_IntOrFpTensor:$values + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + +def TF_WriteImageSummaryOp : TF_Op<"WriteImageSummary", []> { + let summary = "Writes a `Summary` protocol buffer with images."; + + let description = [{ +The summary has up to `max_images` summary values containing images. The +images are built from `tensor` which must be 4-D with shape `[batch_size, +height, width, channels]` and where `channels` can be: + +* 1: `tensor` is interpreted as Grayscale. +* 3: `tensor` is interpreted as RGB. +* 4: `tensor` is interpreted as RGBA. + +The images have the same number of channels as the input tensor. For float +input, the values are normalized one image at a time to fit in the range +`[0, 255]`. `uint8` values are unchanged. The op uses two different +normalization algorithms: + +* If the input values are all positive, they are rescaled so the largest one + is 255. + +* If any input value is negative, the values are shifted so input value 0.0 + is at 127. They are then rescaled so that either the smallest value is 0, + or the largest one is 255. + +The `tag` argument is a scalar `Tensor` of type `string`. It is used to +build the `tag` of the summary values: + +* If `max_images` is 1, the summary value tag is '*tag*/image'. +* If `max_images` is greater than 1, the summary value tags are + generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. + +The `bad_color` argument is the color to use in the generated images for +non-finite input values. It is a `unit8` 1-D tensor of length `channels`. +Each element must be in the range `[0, 255]` (It represents the value of a +pixel in the output image). Non-finite values in the input tensor are +replaced by this tensor in the output image. The default value is the color +red. + +writer: A handle to a summary writer. +step: The step to write the summary for. +tag: Scalar. Used to build the `tag` attribute of the summary values. +tensor: 4-D of shape `[batch_size, height, width, channels]` where + `channels` is 1, 3, or 4. +max_images: Max number of batch elements to generate images for. +bad_color: Color to use for pixels with non-finite values. 
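The float normalization spelled out above can be written directly: all-non-negative inputs are scaled so the maximum maps to 255; otherwise 0.0 is pinned at 127 and the data is scaled until either the minimum reaches 0 or the maximum reaches 255, whichever happens first. A standalone sketch of that rule for a single image of finite values (substituting `bad_color` for non-finite pixels is omitted):

```
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Normalizes one image of float values into [0, 255] following the rule in
// the description above. Assumes all values are finite.
std::vector<uint8_t> NormalizeImage(const std::vector<float> &pixels) {
  float lo = *std::min_element(pixels.begin(), pixels.end());
  float hi = *std::max_element(pixels.begin(), pixels.end());
  std::vector<uint8_t> out(pixels.size());
  if (lo >= 0.0f) {
    // All non-negative: rescale so the largest value becomes 255.
    float scale = hi > 0.0f ? 255.0f / hi : 0.0f;
    for (size_t i = 0; i < pixels.size(); ++i)
      out[i] = static_cast<uint8_t>(pixels[i] * scale);
  } else {
    // Some negative values: map 0.0 to 127, then scale so that either the
    // smallest value lands at 0 or the largest lands at 255.
    float scale = 127.0f / -lo;
    if (hi > 0.0f) scale = std::min(scale, 128.0f / hi);
    for (size_t i = 0; i < pixels.size(); ++i)
      out[i] = static_cast<uint8_t>(127.0f + pixels[i] * scale);
  }
  return out;
}

int main() {
  for (uint8_t v : NormalizeImage({-1.0f, 0.0f, 0.5f, 1.0f}))
    std::cout << int(v) << " ";  // 0 127 190 254 with the scale chosen above
  std::cout << "\n";
}
```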
+ }]; + + let arguments = (ins + Arg:$writer, + I64Tensor:$step, + TF_StrTensor:$tag, + TensorOf<[F16, F32, TF_Uint8]>:$tensor, + TF_Uint8Tensor:$bad_color, + + Confined, [IntMinValue<1>]>:$max_images + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + +def TF_WriteRawProtoSummaryOp : TF_Op<"WriteRawProtoSummary", []> { + let summary = "Writes a `Summary` protocol buffer with serialized string `Summary` protocol buffers."; + + let description = [{ +writer: A handle to a summary writer. +step: The step to write the summary for. +tensor: A tensor holding one or more serialized `Summary` protobufs to write. + }]; + + let arguments = (ins + Arg:$writer, + I64Tensor:$step, + TF_StrTensor:$tensor + ); + + let results = (outs); +} + +def TF_WriteScalarSummaryOp : TF_Op<"WriteScalarSummary", []> { + let summary = "Writes a `Summary` protocol buffer with scalar values."; + + let description = [{ +The input `tag` and `value` must have the scalars. + +writer: A handle to a summary writer. +step: The step to write the summary for. +tag: Tag for the summary. +value: Value for the summary. + }]; + + let arguments = (ins + Arg:$writer, + I64Tensor:$step, + TF_StrTensor:$tag, + TF_IntOrFpTensor:$value + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + +def TF_WriteSummaryOp : TF_Op<"WriteSummary", []> { + let summary = "Outputs a `Summary` protocol buffer with a tensor."; + + let description = [{ +writer: A handle to a summary writer. +step: The step to write the summary for. +tensor: A tensor to serialize. +tag: The summary's tag. +summary_metadata: Serialized SummaryMetadata protocol buffer containing + plugin-related metadata for this summary. + }]; + + let arguments = (ins + Arg:$writer, + I64Tensor:$step, + TF_Tensor:$tensor, + TF_StrTensor:$tag, + TF_StrTensor:$summary_metadata + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +// TODO(b/168035831): Model dataset read. +def TF_InitializeTableFromDatasetOp : TF_Op<"InitializeTableFromDataset", []> { + let summary = ""; + + let arguments = (ins + Arg:$table_handle, + TF_VariantTensor:$dataset + ); + + let results = (outs); +} + +// TODO(b/168035831): Model filename read. +def TF_InitializeTableFromTextFileV2Op : TF_Op<"InitializeTableFromTextFileV2", []> { + let summary = "Initializes a table from a text file."; + + let description = [{ +It inserts one key-value pair into the table for each line of the file. +The key and value is extracted from the whole line content, elements from the +split line based on `delimiter` or the line number (starting from zero). +Where to extract the key and value from a line is specified by `key_index` and +`value_index`. + +- A value of -1 means use the line number(starting from zero), expects `int64`. +- A value of -2 means use the whole line content, expects `string`. +- A value >= 0 means use the index (starting at zero) of the split line based + on `delimiter`. 
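The index convention above (-1: line number, -2: whole line, >= 0: n-th field after splitting on `delimiter`) can be sketched per line. Illustrative only; `vocab_size` checks and error handling are omitted:

```
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Extracts the key or value from one line of the vocabulary file following
// the documented convention: -1 -> line number, -2 -> whole line,
// >= 0 -> the index-th field after splitting on the delimiter.
std::string ExtractField(const std::string &line, int64_t line_number,
                         int index, char delimiter) {
  if (index == -1) return std::to_string(line_number);
  if (index == -2) return line;
  std::vector<std::string> fields;
  std::stringstream ss(line);
  std::string field;
  while (std::getline(ss, field, delimiter)) fields.push_back(field);
  return index < static_cast<int>(fields.size()) ? fields[index] : "";
}

int main() {
  std::string line = "emerson\t42";
  std::cout << ExtractField(line, 0, -2, '\t') << "\n";  // whole line
  std::cout << ExtractField(line, 0, 0, '\t') << "\n";   // "emerson"
  std::cout << ExtractField(line, 0, -1, '\t') << "\n";  // "0" (line number)
}
```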
+ }]; + + let arguments = (ins + Arg:$table_handle, + TF_StrTensor:$filename, + + Confined]>:$key_index, + Confined]>:$value_index, + Confined, [IntMinValue<-1>]>:$vocab_size, + DefaultValuedAttr:$delimiter + ); + + let results = (outs); +} + #endif // TF_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index 5c19f9c3daa..953236602f9 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -37,6 +38,7 @@ limitations under the License. #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -64,6 +66,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/tensor_format.h" @@ -173,6 +176,72 @@ static LogicalResult Verify(BatchMatMulV2Op op) { if (!HasRankAtLeast(op.y(), 2)) { return op.emitOpError("requires rhs operand to have rank at least two"); } + + RankedTensorType x_ty = GetRankedTensorTypeForOperand(op.x()); + RankedTensorType y_ty = GetRankedTensorTypeForOperand(op.y()); + + if (!x_ty || !y_ty) return success(); + + ArrayRef x_shape = x_ty.getShape(); + ArrayRef y_shape = y_ty.getShape(); + + // Check broadcast compatibility if both input shapes are known. + // + // The last two dimensions are non-batch dimensions that don't need to + // participate in batch dimension compatibility check. + + llvm::SmallVector result_batch_shape; + if (!OpTrait::util::getBroadcastedShape( + x_shape.drop_back(2), y_shape.drop_back(2), result_batch_shape)) + return op.emitOpError() + << "found incompatible broadcast batch dimensions for lhs shape " + << x_ty << " and rhs shape " << y_ty; + + RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output()); + if (!output_ty) return success(); + + int64_t expected_output_rank = std::max(x_ty.getRank(), y_ty.getRank()); + if (output_ty.getRank() != expected_output_rank) + return op.emitOpError() + << "found invalid output rank, expected " << expected_output_rank + << " but got " << output_ty.getRank(); + + // Check output batch dim with potential broadcasting. + ArrayRef output_shape = output_ty.getShape(); + for (int i = 0; i < result_batch_shape.size(); ++i) { + if (output_shape[i] != ShapedType::kDynamicSize && + output_shape[i] != result_batch_shape[i]) + return op.emitOpError() + << "has mismatching input batch dimension " + << result_batch_shape[i] << " and output batch dimension " + << output_shape[i]; + } + + // Check output shape for non-batch dimension, following documentation below. 
+ // https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul + int64_t x_row_dim = x_shape[x_shape.size() - 2]; + int64_t x_col_dim = x_shape[x_shape.size() - 1]; + int64_t y_row_dim = y_shape[y_shape.size() - 2]; + int64_t y_col_dim = y_shape[y_shape.size() - 1]; + int64_t out_row_dim = output_shape[output_shape.size() - 2]; + int64_t out_col_dim = output_shape[output_shape.size() - 1]; + + int64_t expected_out_row_dim = op.adj_x() ? x_col_dim : x_row_dim; + int64_t expected_out_col_dim = op.adj_y() ? y_row_dim : y_col_dim; + + if (expected_out_row_dim != ShapedType::kDynamicSize && + out_row_dim != ShapedType::kDynamicSize && + out_row_dim != expected_out_row_dim) + return op.emitOpError() + << "found invalid output dimension on row, expected " + << expected_out_row_dim << " but got " << out_row_dim; + if (expected_out_col_dim != ShapedType::kDynamicSize && + out_col_dim != ShapedType::kDynamicSize && + out_col_dim != expected_out_col_dim) + return op.emitOpError() + << "found invalid output dimension on col, expected " + << expected_out_col_dim << " but got " << out_col_dim; + return success(); } @@ -187,7 +256,7 @@ void BatchMatMulV2Op::getCanonicalizationPatterns( static LogicalResult Verify(BatchToSpaceOp op) { // Op already has a constraint that block_size >= 2. - int64_t block_size = op.block_size().getSExtValue(); + int64_t block_size = op.block_size(); llvm::SmallVector input_shape(4, ShapedType::kDynamicSize); auto input_type = op.input().getType().cast(); @@ -339,15 +408,19 @@ void BatchToSpaceOp::getCanonicalizationPatterns( // are not unknown. // static LogicalResult Verify(BiasAddOp op) { - StringRef format = op.data_format(); - if (format == "NHWC") { + absl::string_view data_format(op.data_format().data(), + op.data_format().size()); + tensorflow::TensorFormat format; + bool is_valid = FormatFromString(data_format, &format); + DCHECK(is_valid) << data_format; + if (format == tensorflow::TensorFormat::FORMAT_NHWC) { if (!HasRankAtLeast(op.value(), 2)) return op.emitOpError( "requires value operand to have rank at least two with `NHWC` data " "format"); } else { // Op definition requires data_format to be either NHWC or NCHW. - DCHECK_EQ(format.str(), "NCHW"); + DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW); if (!HasRankAtLeast(op.value(), 3)) return op.emitOpError( "requires value operand to have rank at least three with `NCHW` data " @@ -361,9 +434,8 @@ static LogicalResult Verify(BiasAddOp op) { RankedTensorType bias_ty = op.bias().getType().dyn_cast(); if (!bias_ty || !value_ty) return success(); - // TODO(hinsu): Leverage tensor_format.h utility in TensorFlow to compute - // dimension indices based on format. - int64_t feature_dim_idx = format == "NHWC" ? value_ty.getRank() - 1 : 1; + int64_t feature_dim_idx = + tensorflow::GetTensorFeatureDimIndex(value_ty.getRank(), format); int64_t feature_dim = value_ty.getDimSize(feature_dim_idx); int64_t bias_len = bias_ty.getDimSize(0); if (feature_dim != -1 && bias_len != -1 && feature_dim != bias_len) { @@ -375,6 +447,13 @@ static LogicalResult Verify(BiasAddOp op) { return success(); } +Optional BiasAddOp::GetContractionFusion() { + // Only NHWC in f32 is supported for fusion. 
+ if (data_format() != "NHWC" || !T().isF32()) return None; + + return ContractionFusion("BiasAdd", /*additional_arguments=*/{1}); +} + //===----------------------------------------------------------------------===// // BiasAddGradOp //===----------------------------------------------------------------------===// @@ -383,15 +462,19 @@ static LogicalResult Verify(BiasAddOp op) { // * the out_backprop operands have valid ranks or are unranked. // static LogicalResult Verify(BiasAddGradOp op) { - StringRef format = op.data_format(); - if (format == "NHWC") { + absl::string_view data_format(op.data_format().data(), + op.data_format().size()); + tensorflow::TensorFormat format; + bool is_valid = FormatFromString(data_format, &format); + DCHECK(is_valid) << data_format; + if (format == tensorflow::TensorFormat::FORMAT_NHWC) { if (!HasRankAtLeast(op.out_backprop(), 2)) return op.emitOpError( "requires out_backprop operand to have rank at least two with `NHWC` " "data format"); } else { // Op definition requires data_format to be either NHWC or NCHW. - DCHECK_EQ(format.str(), "NCHW"); + DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW); if (!HasRankAtLeast(op.out_backprop(), 3)) return op.emitOpError( "requires out_backprop operand to have rank at least three with " @@ -431,6 +514,19 @@ static LogicalResult Verify(BroadcastToOp op) { return success(); } +OpFoldResult BroadcastToOp::fold(ArrayRef operands) { + Value input = this->input(); + + // Fold broadcast if operand and result types are the same and all dimensions + // are statically known (no-op broadcast). + auto result_ty = getType().dyn_cast(); + if (result_ty && result_ty.hasStaticShape() && result_ty == input.getType()) { + return input; + } + + return {}; +} + //===----------------------------------------------------------------------===// // CaseOp //===----------------------------------------------------------------------===// @@ -449,28 +545,119 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite( DenseIntElementsAttr branch; if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure(); - // Only attempt to fold scalar valued case statements. - // TODO(jpienaar): This can be removed if CaseOp's verifier covers it. - if (!branch.getType().cast().getShape().empty()) - return failure(); - int index = *branch.getValues().begin(); - // TODO(jpienaar): This can be removed if CaseOp's verifier covers it. 
- if (index >= op.branches().size()) return failure(); + if (index < 0 || index >= op.branches().size()) + index = op.branches().size() - 1; auto func = op.branches()[index].cast(); auto empty = rewriter.getStringAttr(""); auto call_op = rewriter.create( op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func, /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty); - PropagateDeviceAndInternalAttrs(op.getOperation(), call_op); + CopyDeviceAndUnderscoredAttributes(op.getOperation(), call_op); rewriter.replaceOp(op, call_op.getResults()); return success(); } void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert>(context); +} + +static LogicalResult VerifyCaseOpBase(Operation *op, Value branch_index) { + if (!IsOfRankOrUnranked(branch_index, 0)) + return op->emitOpError() + << "expects 'branch_index' to be a scalar, but got " + << branch_index.getType(); + return success(); +} + +static LogicalResult VerifyCaseOrIfOpBranchFunctions( + Operation *op, ArrayRef branches, + llvm::function_ref branch_name) { + SmallVector branch_types; + branch_types.reserve(branches.size()); + + // Functions have one less operand compared to op as first operand is elided + // (`cond` of `tf.If` and `branch_index` of `tf.Case`). + TypeRangeWithDesc input{op->getOperands().drop_front().getTypes(), "input"}; + TypeRangeWithDesc result{op->getResultTypes(), "result"}; + + for (auto branch : llvm::enumerate(branches)) { + auto branch_func = SymbolTable::lookupNearestSymbolFrom( + op, branch.value().cast()); + if (!branch_func) + return op->emitOpError() + << "expects " << branch_name(branch.index()) << " (" + << branch.value() << ") to point to a defined function"; + + FunctionType branch_type = branch_func.getType(); + std::string desc = branch_name(branch.index()) + " input"; + TypeRangeWithDesc branch_input{branch_type.getInputs(), desc}; + if (failed(VerifyTypeRangesAreCompatible(op, branch_input, input))) + return failure(); + + desc = branch_name(branch.index()) + " result"; + TypeRangeWithDesc branch_result{branch_type.getResults(), desc}; + if (failed(VerifyTypeRangesAreCompatible(op, branch_result, result))) + return failure(); + + branch_types.push_back(branch_type); + } + + // If branches have incompatible input types that means that no tensor can + // serve as input to all the functions. Hence, the op is invalid. 
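The per-index loop that follows asks whether some tensor type could feed every branch at that operand position. `AreCastCompatible` in the real code understands ref types and dynamic shapes; this standalone sketch only captures the shape of the check, with strings standing in for types and "?" meaning an unranked/unknown type:

```
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for AreCastCompatible: "?" is compatible with anything;
// otherwise the concrete types must agree.
bool AreCompatible(const std::vector<std::string> &types) {
  std::string concrete;
  for (const std::string &t : types) {
    if (t == "?") continue;
    if (concrete.empty())
      concrete = t;
    else if (t != concrete)
      return false;
  }
  return true;
}

int main() {
  // Input type of each branch function at each operand index.
  std::vector<std::vector<std::string>> branch_inputs = {
      {"tensor<2xf32>", "tensor<2xf32>"},   // branch #0
      {"?",             "tensor<2xf32>"},   // branch #1
      {"tensor<2xf32>", "tensor<3xf32>"}};  // branch #2
  int num_inputs = 2;
  for (int i = 0; i < num_inputs; ++i) {
    std::vector<std::string> types_at_i;
    for (const auto &branch : branch_inputs) types_at_i.push_back(branch[i]);
    std::cout << "operand " << i
              << (AreCompatible(types_at_i) ? ": compatible" : ": incompatible")
              << "\n";  // operand 0: compatible, operand 1: incompatible
  }
}
```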
+ int expected_num_inputs = op->getNumOperands() - 1; + for (int i = 0; i < expected_num_inputs; ++i) { + SmallVector branch_input_i_types; + branch_input_i_types.reserve(branches.size()); + llvm::transform( + branch_types, std::back_inserter(branch_input_i_types), + [i](FunctionType &branch_type) { return branch_type.getInput(i); }); + if (!AreCastCompatible(branch_input_i_types)) { + std::string input_types_str; + llvm::raw_string_ostream os(input_types_str); + llvm::interleaveComma(branch_input_i_types, os); + return op->emitOpError() + << "expects all branch input type(s) (" << os.str() + << ") at index " << i << " to be cast compatible"; + } + } + + return success(); +} + +static LogicalResult Verify(CaseOp op) { + if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure(); + auto branch_name = [](unsigned index) { + return llvm::formatv("branch #{0}", index).str(); + }; + return VerifyCaseOrIfOpBranchFunctions(op, op.branches().getValue(), + branch_name); +} + +//===----------------------------------------------------------------------===// +// CaseRegionOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(CaseRegionOp op) { + if (op.branches().empty()) + return op.emitOpError() << "expects to have at least 1 region"; + + if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure(); + + TypeRangeWithDesc results{op.getResultTypes(), "result"}; + + for (auto region_and_idx : llvm::enumerate(op.branches())) { + std::string description = + llvm::formatv("branch #{0} result", region_and_idx.index()).str(); + Operation *yield = region_and_idx.value().front().getTerminator(); + TypeRangeWithDesc branch_results{yield->getOperandTypes(), description}; + if (failed(VerifyTypeRangesAreCompatible(op, branch_results, results))) + return failure(); + } + + return success(); } //===----------------------------------------------------------------------===// @@ -727,6 +914,35 @@ void ConcatV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results, context); } +//===----------------------------------------------------------------------===// +// CumsumOp and CumprodOp +//===----------------------------------------------------------------------===// + +template ::value>::type * = nullptr> +static LogicalResult Verify(OpT op) { + if (!IsOfRankOrUnranked(op.axis(), 0)) + return op.emitOpError("requires scalar axis operand"); + + DenseIntElementsAttr axis_attr; + if (matchPattern(op.axis(), m_Constant(&axis_attr))) { + auto input_ty = op.x().getType().template dyn_cast(); + if (input_ty) { + int64_t rank = input_ty.getRank(); + assert(axis_attr.getNumElements() == 1 && + "scalar attribute should have exactly one element"); + int64_t axis = (*axis_attr.begin()).getSExtValue(); + if (axis < -rank || axis >= rank) { + return op.emitError() + << "axis operand should be within range [" << -rank << ", " + << rank << "); actual value: " << axis; + } + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // ConcatOffsetOp //===----------------------------------------------------------------------===// @@ -990,7 +1206,8 @@ static LogicalResult Verify(OpT op) { int64_t input_channels = -1; if (auto ty = op.input().getType().template dyn_cast()) { - std::string data_format = op.data_format().str(); + absl::string_view data_format(op.data_format().data(), + op.data_format().size()); tensorflow::TensorFormat format; auto is_valid = FormatFromString(data_format, 
&format); DCHECK(is_valid) << data_format; @@ -1475,7 +1692,7 @@ static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) { return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) + "," + Twine(std::to_string(rmax)) + "]"); } - int64_t num_bits = op.num_bits().getSExtValue(); + int64_t num_bits = op.num_bits(); if (num_bits < 2 || num_bits > 16) { return op.emitOpError( "requires num_bits to be between 2 and 16, inclusive"); @@ -1495,7 +1712,7 @@ static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) { if (max && !IsOfRankedFloatTensorType(max, 0)) return op.emitOpError("requires max to be a 0d float tensor"); - int64_t num_bits = op.num_bits().getSExtValue(); + int64_t num_bits = op.num_bits(); if (num_bits < 2 || num_bits > 16) { return op.emitOpError( "requires num_bits to be between 2 and 16, inclusive"); @@ -1519,7 +1736,7 @@ static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) { if (!HasRankAtLeast(inputs, 1)) return op.emitError("requires inputs to be at least 1d float tensor"); - int64_t num_bits = op.num_bits().getSExtValue(); + int64_t num_bits = op.num_bits(); if (num_bits < 2 || num_bits > 16) { return op.emitOpError( "requires num_bits to be between 2 and 16, inclusive"); @@ -1722,7 +1939,7 @@ StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) { //===----------------------------------------------------------------------===// static LogicalResult Verify(GatherV2Op op) { - int64_t batch_dims = op.batch_dims().getSExtValue(); + int64_t batch_dims = op.batch_dims(); if (auto ty = op.indices().getType().dyn_cast()) { int64_t rank = ty.getRank(); if (batch_dims > rank || batch_dims < -rank) @@ -1760,79 +1977,18 @@ static LogicalResult Verify(GatherV2Op op) { //===----------------------------------------------------------------------===// static LogicalResult Verify(IfOp op) { - auto then_fn = op.then_func(); - if (!then_fn) - return op.emitOpError("then_branch refers to an undefined function : ") - << op.then_branch(); - auto else_fn = op.else_func(); - if (!else_fn) - return op.emitOpError("else_branch refers to an undefined function : ") - << op.else_branch(); - auto then_fn_type = then_fn.getType(); - auto else_fn_type = else_fn.getType(); - - // Non-conditional operands starting with the second operand are passed to - // branches and should be pair-wise compatible with branches' inputs. - unsigned expected_num_inputs = op.getNumOperands() - 1; - if (then_fn_type.getNumInputs() != expected_num_inputs || - else_fn_type.getNumInputs() != expected_num_inputs) - return op.emitError("branches should have " + Twine(expected_num_inputs) + - " inputs"); - - for (unsigned i = 0; i < expected_num_inputs; ++i) { - auto operand_type = op.getOperand(i + 1).getType().cast(); - auto then_input_type = then_fn_type.getInput(i).cast(); - if (!AreCastCompatible({operand_type, then_input_type})) - return op.emitError( - llvm::formatv("then branch input type {0} is incompatible with " - "operand type {1} at index {2}", - then_input_type, operand_type, i)); - - auto else_input_type = else_fn_type.getInput(i).cast(); - if (!AreCastCompatible({operand_type, else_input_type})) - return op.emitError( - llvm::formatv("else branch input type {0} is incompatible with " - "operand type {1} at index {2}", - else_input_type, operand_type, i)); - - // If branches have incompatible input types that means that no tensor can - // serve as input to both the functions. Hence, the op is invalid. 
- if (!AreCastCompatible({then_input_type, else_input_type})) - return op.emitError(llvm::formatv( - "branches inputs have incompatible types {0} and {1} at index {2}", - then_input_type, else_input_type, i)); - } - - // Branches' results should be pair-wise compatible with the op results. - unsigned expected_num_results = op.getNumResults(); - if (then_fn_type.getNumResults() != expected_num_results || - else_fn_type.getNumResults() != expected_num_results) - return op.emitError("branches should have " + Twine(expected_num_results) + - " results"); - - for (unsigned i = 0; i < expected_num_results; ++i) { - auto result_type = op.getResult(i).getType().cast(); - auto then_result_type = then_fn_type.getResult(i).cast(); - if (!AreCastCompatible({then_result_type, result_type})) - return op.emitError( - llvm::formatv("then branch result type {0} is incompatible with op " - "result type {1} at index {2}", - then_result_type, result_type, i)); - - auto else_result_type = else_fn_type.getResult(i).cast(); - if (!AreCastCompatible({else_result_type, result_type})) - return op.emitError( - llvm::formatv("else branch result type {0} is incompatible with op " - "result type {1} at index {2}", - else_result_type, result_type, i)); - } - return success(); + auto branch_name = [](unsigned index) -> std::string { + return index == 0 ? "'then_branch'" : "'else_branch'"; + }; + return VerifyCaseOrIfOpBranchFunctions( + op, {op.then_branchAttr(), op.else_branchAttr()}, branch_name); } //===----------------------------------------------------------------------===// // IfOp canonicalization. //===----------------------------------------------------------------------===// +namespace { class FoldConstantIfOp : public OpRewritePattern { public: explicit FoldConstantIfOp(MLIRContext *context) @@ -1864,9 +2020,9 @@ LogicalResult FoldConstantIfOp::matchAndRewrite( auto rewrite = [&](auto op_type) { auto empty = rewriter.getStringAttr(""); auto call_op = rewriter.create( - op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func, + op.getLoc(), op.getResultTypes(), op.input(), func, /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty); - PropagateDeviceAndInternalAttrs(op.getOperation(), call_op); + CopyDeviceAndUnderscoredAttributes(op.getOperation(), call_op); rewriter.replaceOp(op, call_op.getResults()); }; @@ -1877,6 +2033,7 @@ LogicalResult FoldConstantIfOp::matchAndRewrite( return success(); } +} // anonymous namespace void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { @@ -1888,13 +2045,77 @@ void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// static LogicalResult Verify(IfRegionOp op) { - if (failed(VerifyRegionResults(op, op.then_branch(), "then"))) + TypeRange then_types = + op.then_branch().front().getTerminator()->getOperandTypes(); + TypeRange else_types = + op.else_branch().front().getTerminator()->getOperandTypes(); + + TypeRangeWithDesc results{op.getResultTypes(), "result"}; + TypeRangeWithDesc then_results{then_types, "then result"}; + TypeRangeWithDesc else_results{else_types, "else result"}; + + if (failed(VerifyTypeRangesAreCompatible(op, then_results, results))) return failure(); - if (failed(VerifyRegionResults(op, op.else_branch(), "else"))) + if (failed(VerifyTypeRangesAreCompatible(op, else_results, results))) return failure(); return success(); } +namespace { +class FoldConstantIfRegionOp : public OpRewritePattern 
{ + public: + explicit FoldConstantIfRegionOp(MLIRContext *context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(TF::IfRegionOp op, + PatternRewriter &rewriter) const override; +}; + +LogicalResult FoldConstantIfRegionOp::matchAndRewrite( + TF::IfRegionOp op, PatternRewriter &rewriter) const { + // Extract the constant cond value. + DenseIntElementsAttr cond_attr; + if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure(); + + // IfRegion condition should always be a scalar. Select the region to fold to. + bool cond = cond_attr.getSplatValue().getValue(); + Region ®ion = cond ? op.then_branch() : op.else_branch(); + + // If the IfRegion is stateless but the region being inlined itself is not + // stateless, then inlining the region could cause a loss of information. + // However, its probably better to fold the IfRegion instead of having the + // dead branch stay. + + // Inline the region in place of the IfRegion op, and forward the yield + // inputs to the IfRegion op results. This is possible only if the yield + // types match the result types. + auto yield = cast(region.front().getTerminator()); + auto updated_results = llvm::to_vector<4>(yield.getOperands()); + + // If the yield types do not match the IfRegion result types, add appropriate + // casts. + rewriter.setInsertionPoint(yield); + for (auto it : llvm::zip(op.getResultTypes(), updated_results)) { + auto &updated_result = std::get<1>(it); + Type result_type = std::get<0>(it); + if (result_type != updated_result.getType()) { + updated_result = + rewriter.create(op.getLoc(), result_type, updated_result, + /*Truncate=*/rewriter.getBoolAttr(false)); + } + } + // Inline the region into the block containing the IfRegion. + rewriter.mergeBlockBefore(®ion.front(), op); + rewriter.eraseOp(yield); + rewriter.replaceOp(op, updated_results); + return success(); +} +} // anonymous namespace + +void IfRegionOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // InvertOp //===----------------------------------------------------------------------===// @@ -1943,6 +2164,15 @@ OpFoldResult LeakyReluOp::fold(ArrayRef operands) { return {}; } +Optional LeakyReluOp::GetContractionFusion() { + // Only f32 is supported for fusion. 
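The two constant-condition canonicalizations above (FoldConstantIfOp and FoldConstantIfRegionOp) share the same gating check: the predicate must be a splat boolean constant, and that single value selects which branch survives. A minimal sketch of that check for the region form, with the template arguments of the accessors written out as assumptions (for example getSplatValue<BoolAttr>):

DenseIntElementsAttr cond_attr;
if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();
// A scalar (splat) predicate picks exactly one region to keep.
bool condition = cond_attr.getSplatValue<BoolAttr>().getValue();
Region &region = condition ? op.then_branch() : op.else_branch();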
+ if (!T().isF32()) return None; + + NamedAttribute alpha(Identifier::get("alpha", getContext()), alphaAttr()); + return ContractionFusion("LeakyRelu", /*additional_arguments=*/{}, + /*additional_attributes=*/{alpha}); +} + //===----------------------------------------------------------------------===// // LogOp //===----------------------------------------------------------------------===// @@ -2064,12 +2294,12 @@ OpFoldResult MulOp::fold(ArrayRef operands) { return IdentityArithmeticOpFolder(*this, operands); } +} // namespace TF +} // namespace mlir + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc.inc" - -} // namespace TF -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h index 19a927a23d7..8d98632b198 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h @@ -43,6 +43,9 @@ namespace TF { class YieldOp; +} // namespace TF +} // namespace mlir + // TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose // purpose is to catch bug on `tensorflow::mutex_lock`. We don't use // `tensorflow::mutex_lock` here but we have ops (`tf.MutexLock` and @@ -56,7 +59,4 @@ class YieldOp; #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h.inc" -} // namespace TF -} // namespace mlir - #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_A_M_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc index 71f1560aa6c..44df2b12d88 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc @@ -18,17 +18,6 @@ limitations under the License. // tf_verifiers or tf_ops. // TODO(jpienaar): Remove this file post refactoring. -// Propagates underscore and device attributes from src to dst. -// TODO(b/158769932): This should be a general feature instead post some policy -// discussion. -static void PropagateDeviceAndInternalAttrs(Operation *src, Operation *dst) { - auto device = mlir::Identifier::get("device", src->getContext()); - for (auto named_attr : src->getAttrs()) { - if (*named_attr.first.begin() == '_' || named_attr.first == device) - dst->setAttr(named_attr.first, named_attr.second); - } -} - //===----------------------------------------------------------------------===// // TF op helper functions //===----------------------------------------------------------------------===// @@ -554,27 +543,27 @@ static LogicalResult VerifyReductionInputAndDims(Value input, Value dims, return success(); } -LogicalResult VerifyRegionResults(Operation *op, Region ®ion, - StringRef region_name) { - auto op_name = op->getName().getStringRef(); - // verify that op outputs match yield inputs - YieldOp yield = cast(region.front().getTerminator()); - unsigned expected_num_results = op->getNumResults(); - if (yield.getNumOperands() != expected_num_results) - return op->emitOpError() - << region_name + " should have same number (" << expected_num_results - << ") of results as " << op_name << " but has " - << yield.getNumOperands() << " results"; +// A type range with description (in singular form) attached to it. 
+using TypeRangeWithDesc = std::pair; - for (int idx : llvm::seq(0, expected_num_results)) { - auto op_result_type = op->getResult(idx).getType().cast(); - auto region_result_type = - yield.getOperand(idx).getType().cast(); - if (!AreCastCompatible({region_result_type, op_result_type})) - return op->emitError(llvm::formatv( - "{0} result type {1} is incompatible with {2} " - "result type {3} at index {4}", - region_name, region_result_type, op_name, op_result_type, idx)); +LogicalResult VerifyTypeRangesAreCompatible(Operation *op, + TypeRangeWithDesc range0, + TypeRangeWithDesc range1) { + if (range0.first.size() != range1.first.size()) { + return op->emitOpError() + << range0.second << "s (size = " << range0.first.size() << ")" + << " should have the same number of values as " << range1.second + << "s (size = " << range1.first.size() << ")"; + } + + for (auto it : llvm::enumerate(llvm::zip(range0.first, range1.first))) { + int index = it.index(); + Type type0 = std::get<0>(it.value()); + Type type1 = std::get<1>(it.value()); + if (!AreCastCompatible({type0, type1})) + return op->emitOpError(llvm::formatv( + "{0} type {1} is incompatible with {2} type {3} at index {4}", + range0.second, type0, range1.second, type1, index)); } return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index ffedcb47f7e..c2f39733c7a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -109,7 +109,7 @@ void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x, //===----------------------------------------------------------------------===// static LogicalResult Verify(OneHotOp op) { - int64_t axis = op.axis().getSExtValue(); + int64_t axis = op.axis(); auto indices_ty = op.indices().getType().dyn_cast(); if (indices_ty && @@ -207,7 +207,7 @@ static LogicalResult Verify(PackOp op) { // the axis value range is [-(R+1), R+1). int64_t range_begin = -inputs_rank - 1; // Inclusive int64_t range_end = inputs_rank + 1; // Exclusive - int64_t axis = op.axis().getSExtValue(); + int64_t axis = op.axis(); if (axis < range_begin || axis >= range_end) { return op.emitError() << "attribute 'axis' should be within range [" << range_begin << ", " << range_end @@ -232,7 +232,7 @@ OpFoldResult PackOp::fold(ArrayRef operands) { if (values().size() < 2) return {}; // Dimensions packed along axis = 0 (pack scalars into vector). - if (axis().getSExtValue() != 0) return {}; + if (axis() != 0) return {}; // First packed value is defined by a strided slice operation. auto slice_op = dyn_cast_or_null(values()[0].getDefiningOp()); @@ -247,11 +247,9 @@ OpFoldResult PackOp::fold(ArrayRef operands) { // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing // scalar value from input vector). 
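Many of the simplifications in this file follow from a single accessor change: integer attributes that previously surfaced as APInt are now returned as plain integers by the generated op classes, so the getSExtValue()/getZExtValue() calls disappear. A one-line sketch using the OneHot axis attribute:

int64_t axis = op.axis();                    // accessor now yields int64_t
// int64_t axis = op.axis().getSExtValue();  // previous APInt-based form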
- if (slice_op.begin_mask().getSExtValue() != 0 || - slice_op.ellipsis_mask().getSExtValue() != 0 || - slice_op.end_mask().getSExtValue() != 0 || - slice_op.new_axis_mask().getSExtValue() != 0 || - slice_op.shrink_axis_mask().getSExtValue() != 1) + if (slice_op.begin_mask() != 0 || slice_op.ellipsis_mask() != 0 || + slice_op.end_mask() != 0 || slice_op.new_axis_mask() != 0 || + slice_op.shrink_axis_mask() != 1) return {}; // Returns a value if the `value` is defined by a ConstOp with a single @@ -566,6 +564,17 @@ OpFoldResult RealDivOp::fold(ArrayRef operands) { return IdentityArithmeticOpFolder(*this, operands); } +//===----------------------------------------------------------------------===// +// ReluOp +//===----------------------------------------------------------------------===// + +Optional ReluOp::GetContractionFusion() { + // Only f32 is supported for fusion. + if (!T().isF32()) return None; + + return ContractionFusion("Relu", /*additional_arguments=*/{}); +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// @@ -707,7 +716,6 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) { // Fold reshape if operand and result types are the same and all dimensions // are statically known (no-op reshape). - // TODO(ezhulenev): Add the same folding for BroadcastToOp. auto result_ty = getType().dyn_cast(); if (result_ty && result_ty.hasStaticShape() && result_ty == tensor.getType()) { @@ -932,24 +940,75 @@ static LogicalResult Verify(ShapeNOp op) { return success(); } -LogicalResult ShapeNOp::fold(ArrayRef operands, - SmallVectorImpl &results) { - if (getNumOperands() == 0) return success(); - int width = - getType(0).cast().getElementType().getIntOrFloatBitWidth(); - - for (Type input_ty : getOperandTypes()) { - OpFoldResult result = ConvertShapeToAttr(input_ty, width); - if (!result) return failure(); - - results.push_back(result); - } - return success(); -} - -// TODO(hinsu): Add canonicalization pattern for ShapeN ops that don't have all +namespace { +// Canonicalization pattern for ShapeNOp that don't have all // static input shapes. Replacing output values corresponding to static input // types may enable optimizations in users of the values. +class ShapeNPartialStaticInputShape : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ShapeNOp op, + PatternRewriter &rewriter) const override { + if (op.getNumOperands() == 0) { + rewriter.eraseOp(op); + return success(); + } + + int width = getElementTypeOrSelf(op.getType(0)).getIntOrFloatBitWidth(); + + SmallVector results(op.getNumOperands()); + SmallVector dynamic_indices; + SmallVector dynamic_inputs; + SmallVector result_types; + for (auto e : llvm::enumerate(op.getOperands())) { + if (Attribute result = ConvertShapeToAttr(e.value().getType(), width)) { + results[e.index()] = rewriter.create(op.getLoc(), result); + } else { + dynamic_indices.push_back(e.index()); + dynamic_inputs.push_back(e.value()); + result_types.push_back(op.getType(e.index())); + } + } + + if (dynamic_inputs.size() == op.getNumOperands()) { + // Cannot canonicalize ShapeN if all inputs are dynamic. + return failure(); + } + + // Create a ShapeNOp for all dynamic inputs. 
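This pattern and the single-operand ShapeNToShape pattern defined further below are assumed to be registered together in ShapeNOp::getCanonicalizationPatterns, roughly as sketched here, so a one-input ShapeN collapses to Shape while mixed static/dynamic inputs split into constants plus a smaller ShapeN:

void ShapeNOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  // Assumed registration list and order.
  results.insert<ShapeNToShape, ShapeNPartialStaticInputShape>(context);
}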
+ if (!dynamic_inputs.empty()) { + auto dynamic_shape_n = rewriter.create( + op.getLoc(), result_types, dynamic_inputs); + for (auto index_result : + llvm::zip(dynamic_indices, dynamic_shape_n.getResults())) { + results[std::get<0>(index_result)] = std::get<1>(index_result); + } + } + + rewriter.replaceOp(op, results); + return success(); + } +}; + +// Canonicalize ShapeNOp to ShapeOp if there is only one operand. +class ShapeNToShape : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ShapeNOp op, + PatternRewriter &rewriter) const override { + if (op.getNumOperands() != 1) { + return failure(); + } + auto shape = rewriter.create(op.getLoc(), op.getType(0), + op.getOperand(0)); + rewriter.replaceOp(op, {shape}); + return success(); + } +}; +} // namespace + +void ShapeNOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} //===----------------------------------------------------------------------===// // SizeOp @@ -964,9 +1023,23 @@ static LogicalResult Verify(SizeOp op) { return op.emitOpError( "requires ranked input tensor to be of rank INT32_MAX or less"); + // Output type needs to be scalar. + if (!IsOfRankOrUnranked(op.output(), /*rank=*/0)) + return op.emitOpError("requires scalar output"); + return success(); } +OpFoldResult SizeOp::fold(ArrayRef operands) { + ShapedType output_type = getType().cast(); + ShapedType input_type = getOperand().getType().cast(); + if (!input_type.hasStaticShape()) return {}; + int size = input_type.getNumElements(); + return DenseElementsAttr::get( + output_type, + IntegerAttr::get(output_type.getElementType(), /*value=*/size)); +} + //===----------------------------------------------------------------------===// // SliceOp //===----------------------------------------------------------------------===// @@ -978,8 +1051,11 @@ static LogicalResult Verify(SizeOp op) { // of elements in operands begin and size. // - if begin are constants, that // 0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i] +// and +// size[i] == output_ty.getShape()[i] // - if begins aren't constant but the input is a ranked tensor, that // size[i] <= input_ty.getShape()[i] +// - output rank is the same as input rank // static LogicalResult Verify(SliceOp op) { RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin()); @@ -1007,21 +1083,40 @@ static LogicalResult Verify(SliceOp op) { "are equal to input rank"; } + auto output_ty = op.output().getType().dyn_cast(); + if (output_ty && input_ty && output_ty.getRank() != input_ty.getRank()) { + return op.emitOpError() + << "requires output to have the same rank as input, but got input " + "rank " + << input_ty.getRank() << " and output rank " << output_ty.getRank(); + } + DenseIntElementsAttr begin_indices; if (matchPattern(op.begin(), m_Constant(&begin_indices))) { DenseIntElementsAttr slice_sizes; bool constant_slice_sizes = matchPattern(op.size(), m_Constant(&slice_sizes)); int dim = 0; + // TODO(jpienaar): Reformulate the shape verification below to not use magic + // constants. for (const APInt &raw_begin_index : begin_indices.getValues()) { int64_t begin_index = raw_begin_index.getSExtValue(); int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1; int64_t slice_size = constant_slice_sizes ? slice_sizes.getValue(dim).getSExtValue() : 0; + int64_t output_size = output_ty ? 
output_ty.getShape()[dim] : -1; + if (slice_size == -1 && input_size != -1) { slice_size = input_size - begin_index; } + if (output_size != -1 && constant_slice_sizes && + output_size != slice_size) { + return op.emitOpError() + << "requires output size to have the same size of slice, got " + "slice size " + << slice_size << " and output size " << output_size; + } if (begin_index < 0 || (input_size != -1 && begin_index + slice_size > input_size)) { return op.emitOpError() @@ -1079,6 +1174,13 @@ static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) { return success(); } +//===----------------------------------------------------------------------===// +// SpaceToBatchNDOp +//===----------------------------------------------------------------------===// + +// TODO(b/157475606): Add Verify(SpaceToBatchNDOp) +// TODO(b/157475606): Add SpaceToBatchNDOp::inferReturnTypes + //===----------------------------------------------------------------------===// // SparseSoftmaxCrossEntropyWithLogitsOp //===----------------------------------------------------------------------===// @@ -1325,7 +1427,7 @@ static LogicalResult VerifyStridedSliceBase(OpTy op) { // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there // exists only no more than one ellipsis. - uint32_t ellipsis_mask = op.ellipsis_mask().getZExtValue(); + uint32_t ellipsis_mask = op.ellipsis_mask(); if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask)) return op.emitOpError("cannot have multiple ellipses"); @@ -1581,10 +1683,9 @@ bool StridedSliceOp::GetSlicedBoundRanges( sparse_strides.push_back(stride.getSExtValue()); CalculateSlicedShapeFromSparseIndices( - input_shape, sparse_begin, sparse_end, sparse_strides, - begin_mask().getZExtValue(), end_mask().getZExtValue(), - ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(), - shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride); + input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(), + end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(), + slice_begin, slice_end, slice_stride); return true; } @@ -1635,10 +1736,9 @@ bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges( sparse_strides.push_back(stride.getSExtValue()); CalculateSlicedShapeFromSparseIndices( - *input_shape, sparse_begin, sparse_end, sparse_strides, - begin_mask().getZExtValue(), end_mask().getZExtValue(), - ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(), - shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride); + *input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(), + end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(), + slice_begin, slice_end, slice_stride); return true; } @@ -1712,6 +1812,87 @@ static LogicalResult Verify(TensorScatterUpdateOp op) { return success(); } +//===----------------------------------------------------------------------===// +// TileOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// +// - input has at least rank 1 +// - multiples is rank 1 +// - multiples.size() == input.rank() +// - input.rank() == output.rank() +// - Elements in multiples are non-negative +// - input.shape[i] * multiples[i] == output.shape[i] +// for i in [0, input.rank() - 1] + +static LogicalResult Verify(TileOp op) { + auto input_type = op.input().getType().dyn_cast(); + auto multiples_type = op.multiples().getType().dyn_cast(); + auto output_type = op.output().getType().dyn_cast(); + + if (multiples_type && 
multiples_type.getRank() != 1) { + return op.emitOpError() << "expected multiples to be rank 1, got rank = " + << multiples_type.getRank(); + } + + if (input_type && multiples_type && multiples_type.hasStaticShape() && + (input_type.getRank() != multiples_type.getNumElements() || + (input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) { + return op.emitOpError() + << "expected size of multiples equal to rank of input" + << ", got multiples of size " << multiples_type.getNumElements() + << ", and input of rank " << input_type.getRank(); + } + + if (input_type && output_type) { + if (input_type.getRank() != output_type.getRank()) { + return op.emitOpError() + << "expected rank of input to equal to rank of output" + << ", got input of rank " << input_type.getRank() + << ", and output of rank " << output_type.getRank(); + } + + DenseIntElementsAttr multiples_attr; + if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) { + for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) { + const int64_t input_dim = input_type.getDimSize(i); + const int64_t output_dim = output_type.getDimSize(i); + const int64_t m = multiples_attr.getValue(i).getSExtValue(); + + if (m < 0) { + return op.emitOpError() + << "expected multiples to be non-negative, got " + << "multiples[" << i << "] = " << m; + } + + if (!ShapedType::isDynamic(input_dim) && + !ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) { + return op.emitOpError() + << "requires input.shape[" << i << "] (" << input_dim << ")" + << " * " << m << " to be equal to " + << "output.shape[" << i << "] (" << output_dim << ")"; + } + } + } + } + + return success(); +} + +OpFoldResult TileOp::fold(ArrayRef operands) { + DenseIntElementsAttr multiples_attr; + if (matchPattern(multiples(), m_Constant(&multiples_attr))) { + // Return input directly when multiples are all ones, + // regardless what input is. + if (multiples_attr.isSplat() && + multiples_attr.getSplatValue().getSExtValue() == 1) { + return input(); + } + } + return {}; +} + //===----------------------------------------------------------------------===// // TopKV2Op //===----------------------------------------------------------------------===// @@ -1732,26 +1913,57 @@ static LogicalResult Verify(TopKV2Op op) { //===----------------------------------------------------------------------===// namespace { -// If the input to ToBoolOp is a `tensor`, then the ToBoolOp is an identity -// function and can be removed. -class ToBoolOfZeroDBoolTensor : public OpRewritePattern { +// If the input to ToBoolOp is a ranked tensor, then the ToBoolOp can be folded +// into an identity or an equality comparison. +class ToBoolOfRankedTensor : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ToBoolOp op, PatternRewriter &rewriter) const override { - if (auto type = op.getOperand().getType().dyn_cast()) { - if (type.getRank() == 0 && type.getElementType().isInteger(1)) { - rewriter.replaceOp(op, op.getOperand()); - return success(); - } + auto type = op.getOperand().getType().dyn_cast(); + // If the input is an unranked tensor, cannpt rewrite. + if (!type) return failure(); + + // Expected return type of the ToBool operation. + auto result_type = op.getResult().getType().cast(); + + // If input is already a tensor, it can be folded into an identity. 
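Beyond the identity case, the rewrite handles two more situations just below; a compact sketch of both, using the same local names as the pattern and with the op template arguments written out as assumptions (NotEqualOp for the scalar case, ConstOp for rank one and higher):

if (type.getRank() == 0) {
  // Scalar numeric or string input: ToBool(x) becomes NotEqual(x, zero or "").
  auto zero_const = rewriter.create<TF::ConstOp>(op.getLoc(), zero_attr);
  rewriter.replaceOpWithNewOp<TF::NotEqualOp>(
      op, result_type, op.getOperand(), zero_const,
      /*incompatible_shape_error=*/false);
} else {
  // Rank >= 1 input: the result is a constant that is true iff no dimension
  // is statically zero, i.e. numElements != 0.
  bool any_zero =
      llvm::any_of(type.getShape(), [](int64_t dim) { return dim == 0; });
  rewriter.replaceOpWithNewOp<TF::ConstOp>(
      op, result_type, DenseElementsAttr::get(result_type, {!any_zero}));
}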
+ if (type == result_type) { + rewriter.replaceOp(op, op.getOperand()); + return success(); } - return failure(); + + if (type.getRank() == 0) { + // If the input is a scalar tensor, the ToBool can be expanded to + // element != 0 (for numerical values) or element == empty (for string). + Type element_type = type.getElementType(); + Attribute zero_attr; + if (element_type.isIntOrFloat()) + zero_attr = rewriter.getZeroAttr(type); + else if (element_type.isa()) + zero_attr = DenseStringElementsAttr::get(type, {""}); + + if (!zero_attr) return failure(); + + auto zero_const = rewriter.create(op.getLoc(), zero_attr); + rewriter.replaceOpWithNewOp( + op, result_type, op.getOperand(), zero_const, false); + } else { + // If the input is a non-scalar ranked tensor, ToBool can be expanded + // to numElements != 0. numElements will be 0 iff one of the dimensions is + // zero. + bool any_zero = + llvm::any_of(type.getShape(), [](int64_t dim) { return dim == 0; }); + rewriter.replaceOpWithNewOp( + op, result_type, DenseElementsAttr::get(result_type, {!any_zero})); + } + return success(); } }; } // namespace void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -1844,11 +2056,9 @@ void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x, namespace { OpFoldResult FoldIdentityTranspose(TransposeOp op) { - auto const_perm = dyn_cast_or_null(op.perm().getDefiningOp()); - if (!const_perm) return {}; - - auto const_value = const_perm.value(); - const auto elements = const_value.getValues(); + DenseIntElementsAttr perm; + if (!matchPattern(op.perm(), m_Constant(&perm))) return {}; + const auto elements = perm.getValues(); for (auto it : llvm::enumerate(elements)) { if (it.index() != it.value()) return {}; @@ -1871,14 +2081,14 @@ OpFoldResult FoldCancellableTranspose(TransposeOp op) { if (!transpose) return {}; // Permutations defined by constant operations. 
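For the cancellation fold below, two constant permutations cancel when composing them yields the identity permutation. AreCancellablePermutations is assumed to implement that check; a standalone sketch of the condition it encodes:

// Composing the two permutations must give the identity permutation.
static bool ComposesToIdentity(ArrayRef<int64_t> perm0,
                               ArrayRef<int64_t> perm1) {
  if (perm0.size() != perm1.size()) return false;
  for (int64_t i = 0, e = perm0.size(); i < e; ++i)
    if (perm0[perm1[i]] != i) return false;
  return true;
}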
- auto perm0 = dyn_cast_or_null(op.perm().getDefiningOp()); - auto perm1 = dyn_cast_or_null(transpose.perm().getDefiningOp()); - if (!perm0 || !perm1) return {}; + DenseIntElementsAttr perm0; + DenseIntElementsAttr perm1; + if (!matchPattern(op.perm(), m_Constant(&perm0)) || + !matchPattern(transpose.perm(), m_Constant(&perm1))) + return {}; // With permutation indices that cancel each other - auto perm0_value = perm0.value().cast(); - auto perm1_value = perm1.value().cast(); - if (!AreCancellablePermutations(perm0_value, perm1_value)) return {}; + if (!AreCancellablePermutations(perm0, perm1)) return {}; return transpose.x(); } @@ -1909,7 +2119,7 @@ static LogicalResult Verify(UnpackOp op) { if (!value_type) return success(); int64_t value_rank = value_type.getRank(); - int64_t axis = op.axis().getSExtValue(); + int64_t axis = op.axis(); if (axis < -value_rank || axis >= value_rank) return op.emitOpError("axis attribute must be in the range of [-") << value_rank << ", " << value_rank << ')'; @@ -2029,38 +2239,19 @@ OpFoldResult VariableShapeOp::fold(ArrayRef operands) { // WhileOp //===----------------------------------------------------------------------===// -static LogicalResult Verify(WhileOp op) { - auto cond_fn = op.cond_func(); - auto body_fn = op.body_func(); - if (!cond_fn) { - return op.emitOpError("cond refers to an undefined function : ") - << op.cond(); - } - if (!body_fn) { - return op.emitOpError("body refers to an undefined function : ") - << op.body(); - } - - auto cond_fn_type = cond_fn.getType(); - auto body_fn_type = body_fn.getType(); - - // Verify that the cond function has exactly one result. - if (cond_fn_type.getNumResults() != 1) - return op.emitOpError("requires cond function to have exactly one result"); - - SmallVector operands(op.getOperandTypes()); - +static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input, + TypeRange body_input, + TypeRange body_result) { // Collect all the type lists for the op so that different pairs of type lists // can be compared for the compatibility. constexpr int kNumTypeLists = 5; - const std::array>, kNumTypeLists> - type_lists = {{ - {"operand", operands}, - {"body function result", body_fn_type.getResults()}, - {"result", op.getResultTypes()}, - {"cond function input", cond_fn_type.getInputs()}, - {"body function input", body_fn_type.getInputs()}, - }}; + const std::array type_lists = {{ + {op->getOperandTypes(), "input"}, + {body_result, "body result"}, + {op->getResultTypes(), "result"}, + {cond_input, "condition input"}, + {body_input, "body input"}, + }}; // A pair of type lists should be cast compatible with each other if one is // converted to the another for a function call or assignment or there is a @@ -2090,28 +2281,38 @@ static LogicalResult Verify(WhileOp op) { for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) { auto &a = type_lists[i]; auto &b = type_lists[j]; - - int a_size = a.second.size(); - if (a_size != b.second.size()) - return op.emitOpError( - llvm::formatv("requires the number of {0}s to be equal to the " - "number of {1}s. 
Found {2} and {3}, respectively", - a.first, b.first, a_size, b.second.size())); - - for (int idx = 0; idx < a_size; ++idx) { - auto a_type = a.second[idx]; - auto b_type = b.second[idx]; - - if (!AreCastCompatible({a_type, b_type})) - return op.emitError(llvm::formatv( - "{0} type {1} is incompatible with {2} type {3} at index {4}", - a.first, a_type, b.first, b_type, idx)); - } + if (failed(VerifyTypeRangesAreCompatible(op, a, b))) return failure(); } } return success(); } +static LogicalResult Verify(WhileOp op) { + auto cond_fn = op.cond_function(); + auto body_fn = op.body_function(); + if (!cond_fn) { + return op.emitOpError("cond refers to an undefined function : ") + << op.cond(); + } + if (!body_fn) { + return op.emitOpError("body refers to an undefined function : ") + << op.body(); + } + + auto cond_fn_type = cond_fn.getType(); + auto body_fn_type = body_fn.getType(); + + // Verify that the cond function has exactly one result. + if (cond_fn_type.getNumResults() != 1) + return op.emitOpError("requires cond function to have exactly one result"); + + if (failed(VerifyWhileTypes(op, /*cond_input=*/cond_fn_type.getInputs(), + /*body_input=*/body_fn_type.getInputs(), + /*body_result=*/body_fn_type.getResults()))) + return failure(); + return success(); +} + //===----------------------------------------------------------------------===// // WhileOp canonicalization. //===----------------------------------------------------------------------===// @@ -2125,50 +2326,23 @@ void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// static LogicalResult Verify(WhileRegionOp op) { // Verify that the condition generates a single tensor result. - YieldOp yield = cast(op.cond().front().getTerminator()); - if (yield.getNumOperands() != 1) + Operation *cond_yield = op.cond().front().getTerminator(); + if (cond_yield->getNumOperands() != 1) return op.emitOpError() << "condition should have a single tensor result"; - auto cond_type = yield.getOperand(0).getType().dyn_cast(); + auto cond_type = + cond_yield->getOperand(0).getType().dyn_cast(); if (!cond_type || !cond_type.getShape().equals({}) || !cond_type.getElementType().isInteger(/*width=*/1)) return op.emitOpError() << "condition should have a single tensor result"; - // The body result types should match while op result types. - if (failed(VerifyRegionResults(op, op.body(), "body"))) return failure(); - - // Both condition and body should have same number and type of operands as - // the WhileRegion inputs. 
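VerifyWhileTypes above compares every pair of type lists except (input, body result); a short sketch of the pair enumeration the nested loop performs, with indices following the declaration order of type_lists:

// 0: op input, 1: body result, 2: op result, 3: condition input, 4: body input
for (int i = 0; i < kNumTypeLists; ++i)
  for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) {
    // Visits (0,2) (0,3) (0,4) (1,2) (1,3) (1,4) (2,3) (2,4) (3,4).
    if (failed(VerifyTypeRangesAreCompatible(op, type_lists[i], type_lists[j])))
      return failure();
  }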
- const int num_inputs = op.getNumOperands(); - auto block_inputs_match_op_inputs = [&](Region ®ion, - StringRef name) -> LogicalResult { - Block &block = region.front(); - if (block.getNumArguments() != num_inputs) - return op.emitOpError() - << name << " should have same number of inputs (" << num_inputs - << ") as " << WhileRegionOp::getOperationName() << " but has " - << block.getNumArguments() << " inputs"; - - for (auto types_idx : llvm::enumerate( - llvm::zip(op.getOperandTypes(), block.getArgumentTypes()))) { - auto op_input_type = std::get<0>(types_idx.value()); - auto block_input_type = std::get<1>(types_idx.value()); - if (!AreCastCompatible({block_input_type, op_input_type})) - return op.emitOpError(llvm::formatv( - "{0} input type {1} is incompatible with {2} " - "input type {3} at index {4}", - name, block_input_type, WhileRegionOp::getOperationName(), - op_input_type, types_idx.index())); - } - return success(); - }; - - if (failed(block_inputs_match_op_inputs(op.cond(), "condition")) || - failed(block_inputs_match_op_inputs(op.body(), "body"))) + Operation *body_yield = op.body().front().getTerminator(); + if (failed(VerifyWhileTypes(op, /*cond_input=*/op.cond().getArgumentTypes(), + /*body_input=*/op.body().getArgumentTypes(), + /*body_result=*/body_yield->getOperandTypes()))) return failure(); - return success(); } @@ -2280,7 +2454,8 @@ struct WhileRegionEliminatePassThrough auto &new_body_block = new_while_op.body().front(); auto &new_yield = *new_body_block.getTerminator(); - // Build a vector of new results. Also patch up the region bodies and yield. + // Build a vector of new results. Also patch up the region bodies and + // yield. SmallVector new_results; next_idx = 0; for (int op_idx : llvm::seq(0, old_num_operands)) { @@ -2315,12 +2490,12 @@ void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +} // namespace TF +} // namespace mlir + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc.inc" - -} // namespace TF -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h index 761c06a475c..9b06d855b01 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h @@ -38,15 +38,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" -namespace mlir { -namespace TF { - #define GET_OP_FWD_DEFINES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc" #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h.inc" -} // namespace TF -} // namespace mlir - #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_N_Z_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc index e87cc494a4a..38f9175a500 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc @@ -70,11 +70,12 @@ limitations under the License. 
namespace mlir { namespace TF { - namespace { #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc" #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" } // namespace +} // namespace TF +} // namespace mlir //===----------------------------------------------------------------------===// // TableGen'd op method definitions @@ -82,6 +83,3 @@ namespace { #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc.inc" - -} // namespace TF -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h index 8586515edee..589e0e91615 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h @@ -36,15 +36,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" -namespace mlir { -namespace TF { - #define GET_OP_FWD_DEFINES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc" #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h.inc" -} // namespace TF -} // namespace mlir - #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_REMAINING_OPS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index 94a792ec3db..1eaf997ab69 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -105,15 +105,27 @@ static LogicalResult Verify(SessionInitializerOp session_initializer) { return success(); } +} // namespace tf_saved_model +} // namespace mlir + #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc" +namespace mlir { +namespace tf_saved_model { + //===----------------------------------------------------------------------===// // TensorFlowSavedModelDialect Dialect //===----------------------------------------------------------------------===// TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context) - : Dialect(/*name=*/"tf_saved_model", context) { + : Dialect(/*name=*/"tf_saved_model", context, + TypeID::get()) { + // The TensorFlow Dialect is needed in the verifier and other routines + // associated to this dialect. It makes little sense anyway to use the + // SavedModel dialect without the TensorFlow Dialect. + context->loadDialect(); + addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h index 02b7f0b75f4..c8518a9ca02 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h @@ -40,10 +40,16 @@ class TensorFlowSavedModelDialect : public Dialect { static StringRef getDialectNamespace() { return "tf_saved_model"; } }; +} // namespace tf_saved_model +} // namespace mlir + // Declares the operations for this dialect using the generated header. #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h.inc" +namespace mlir { +namespace tf_saved_model { + // Returns the list of exported names for `op`. // An empty list means `op` is not exported. 
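The tf_saved_model dialect constructor above registers its TypeID and loads the TensorFlow dialect it depends on before adding its operations. A sketch of that constructor with the template arguments written out as assumptions:

TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context)
    : Dialect(/*name=*/"tf_saved_model", context,
              TypeID::get<TensorFlowSavedModelDialect>()) {
  // tf_saved_model verifiers inspect tf.* ops, so the TF dialect must be
  // loaded before this dialect is used.
  context->loadDialect<TF::TensorFlowDialect>();

  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
      >();
}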
SmallVector GetExportedNames(Operation *op); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td index a22a684953b..753e2368d6e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td @@ -82,7 +82,7 @@ def TfSavedModel_Dialect : Dialect { with "get_global @some_global_tensor" in the function body. }]; - let cppNamespace = "tf_saved_model"; + let cppNamespace = "::mlir::tf_saved_model"; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h index 9be61b1db39..8dc5ffb5d09 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h @@ -35,6 +35,14 @@ struct TensorArray : ::mlir::SideEffects::Resource::Base { StringRef getName() final { return "TensorArray"; } }; +struct Summary : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "Summary"; } +}; + +struct LookupTable : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "LookupTable"; } +}; + } // namespace ResourceEffects } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc index 6c5485c16dd..9d8f25c6633 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc @@ -15,11 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" -namespace mlir { - -// NOLINTNEXTLINE #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc.inc" +namespace mlir { namespace TF { void RuntimeDevices::AddDevice(const ParsedName& device) { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h index b1f39ad1d28..b90bf2d47a8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h @@ -26,10 +26,9 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/core/util/device_name_utils.h" -namespace mlir { - #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h.inc" +namespace mlir { namespace TF { // Tensorflow devices available at runtime with corresponding metadata if it is diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h index fc8e6f40f65..412bf113a0f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -33,7 +33,7 @@ namespace TF { static inline LogicalResult VerifyRefTypeMatch(mlir::Type type, mlir::Type maybe_ref_type) { if (auto ref_type = maybe_ref_type.dyn_cast()) - return success(ref_type.RemoveRef().getKind() == type.getKind()); + return success(ref_type.RemoveRef().getTypeID() == type.getTypeID()); return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc index 994378ea1cf..50f034e8ba1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc @@ -17,6 +17,7 @@ limitations under the License. 
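The Summary and LookupTable side-effect resources added above follow the same CRTP shape as the existing TensorArray resource; the base template argument is assumed to be the resource struct itself, for example:

struct Summary : ::mlir::SideEffects::Resource::Base<Summary> {
  StringRef getName() final { return "Summary"; }
};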
#include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project @@ -61,6 +62,192 @@ bool GetCastCompatibleShape(llvm::ArrayRef a_shape, return true; } +} // namespace + +namespace mlir { +namespace TF { +//===----------------------------------------------------------------------===// +// Utility iterators +//===----------------------------------------------------------------------===// + +OperandShapeIterator::OperandShapeIterator(Operation::operand_iterator it) + : llvm::mapped_iterator> (*)(Value)>( + it, &GetShape) {} + +ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it) + : llvm::mapped_iterator> (*)(Value)>( + it, &GetShape) {} + +//===----------------------------------------------------------------------===// +// TF types helper functions +//===----------------------------------------------------------------------===// + +bool TensorFlowType::classof(Type type) { + return type.getDialect().getNamespace() == "tf"; +} +bool TensorFlowRefType::classof(Type type) { + return type.isa< +#define HANDLE_TF_TYPE(tftype, enumerant, name) +#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type, +#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type +// NOLINTNEXTLINE +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" + >(); +} +bool TensorFlowTypeWithSubtype::classof(Type type) { + return type.isa(); +} + +TensorFlowType TensorFlowRefType::get(Type type) { + MLIRContext* ctx = type.getContext(); + type = getElementTypeOrSelf(type); + if (type.isF16()) { + return HalfRefType::get(ctx); + } else if (type.isF32()) { + return FloatRefType::get(ctx); + } else if (type.isF64()) { + return DoubleRefType::get(ctx); + } else if (type.isBF16()) { + return Bfloat16RefType::get(ctx); + } else if (auto complex_type = type.dyn_cast()) { + Type etype = complex_type.getElementType(); + if (etype.isF32()) { + return Complex64RefType::get(ctx); + } else if (etype.isF64()) { + return Complex128RefType::get(ctx); + } + llvm_unreachable("unexpected complex type"); + } else if (auto itype = type.dyn_cast()) { + switch (itype.getWidth()) { + case 1: + return BoolRefType::get(ctx); + case 8: + return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx)) + : Int8RefType::get(ctx); + case 16: + return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx)) + : Int16RefType::get(ctx); + case 32: + return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx)) + : Int32RefType::get(ctx); + case 64: + return itype.isUnsigned() ? 
TensorFlowType(Uint64RefType::get(ctx)) + : Int64RefType::get(ctx); + default: + llvm_unreachable("unexpected integer type"); + } + } +#define HANDLE_TF_TYPE(tftype, enumerant, name) \ + if (auto derived_ty = type.dyn_cast()) \ + return tftype##RefType::get(ctx); + +#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) +// NOLINTNEXTLINE +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" + llvm_unreachable("unexpected type kind"); +} + +Type TensorFlowRefType::RemoveRef() { + MLIRContext* ctx = getContext(); + if (isa()) return mlir::FloatType::getF16(ctx); + if (isa()) return mlir::FloatType::getF32(ctx); + if (isa()) return mlir::FloatType::getF64(ctx); + if (isa()) return mlir::FloatType::getBF16(ctx); + if (isa()) return mlir::IntegerType::get(1, ctx); + if (isa()) return mlir::IntegerType::get(8, ctx); + if (isa()) return mlir::IntegerType::get(16, ctx); + if (isa()) return mlir::IntegerType::get(32, ctx); + if (isa()) return mlir::IntegerType::get(64, ctx); + if (isa()) + return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx); + if (isa()) + return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx); + if (isa()) + return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx); + if (isa()) + return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx); + if (isa()) + return mlir::ComplexType::get(mlir::FloatType::getF32(ctx)); + if (isa()) + return mlir::ComplexType::get(mlir::FloatType::getF64(ctx)); +#define HANDLE_TF_TYPE(tftype, enumerant, name) \ + if (isa()) return tftype##Type::get(ctx); + +#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) +// NOLINTNEXTLINE +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" + llvm_unreachable("unexpected tensorflow ref type kind"); +} + +Type TensorFlowTypeWithSubtype::RemoveSubtypes() { + MLIRContext* ctx = getContext(); + if (isa()) return VariantType::get(ctx); + if (isa()) return ResourceType::get(ctx); + llvm_unreachable("unexpected tensorflow type with subtypes kind"); +} + +ArrayRef TensorFlowTypeWithSubtype::GetSubtypes() { + if (auto variant_type = dyn_cast()) + return variant_type.getSubtypes(); + if (auto resource_type = dyn_cast()) + return resource_type.getSubtypes(); + llvm_unreachable("unexpected tensorflow type with subtypes kind"); +} + +// TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have +// similar structure that could be extracted into helper method. +bool BroadcastCompatible(ArrayRef lhs, ArrayRef rhs) { + if (lhs.size() != rhs.size()) return false; + for (auto types : llvm::zip(lhs, rhs)) { + // Drop ref types because they don't affect broadcast compatibility. E.g., + // `tensor` and `tensor` should be considered broadcast + // compatible. + auto lhs_type = DropRefType(std::get<0>(types)); + auto rhs_type = DropRefType(std::get<1>(types)); + + // This should be true for all TF ops: + auto lhs_tt = lhs_type.dyn_cast(); + auto rhs_tt = rhs_type.dyn_cast(); + if (!lhs_tt || !rhs_tt) { + if (lhs_type != rhs_type) return false; + continue; + } + + // Verify matching element types. These should be identical, except for + // variant type where unknown subtype is considered compatible with all + // subtypes. + auto lhs_et = lhs_tt.getElementType(); + auto rhs_et = rhs_tt.getElementType(); + if (lhs_et != rhs_et) { + // If either does not have subtypes, then the element types don't match. + auto lhs_wst = lhs_et.dyn_cast(); + auto rhs_wst = rhs_et.dyn_cast(); + if (!lhs_wst || !rhs_wst) return false; + + // Consider the subtype of variant types. 
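With the kind-based switches above replaced by classof/isa dispatch, call sites keep using ordinary casts. A minimal helper sketch (hypothetical name, not part of this change) that strips a ref element type when one is present:

static Type StripRefIfPresent(Type type) {
  // TensorFlowRefType::classof is now driven by the generated type list, so
  // dyn_cast works without a type kind.
  if (auto ref = type.dyn_cast<TF::TensorFlowRefType>()) return ref.RemoveRef();
  return type;
}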
+ auto lhs_wst_st = lhs_wst.GetSubtypes(); + auto rhs_wst_st = rhs_wst.GetSubtypes(); + if (!lhs_wst_st.empty() && !rhs_wst_st.empty()) { + for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) { + if (!BroadcastCompatible(std::get<0>(subtypes), + std::get<1>(subtypes))) + return false; + } + } + } + + auto lhs_rt = lhs_type.dyn_cast(); + auto rhs_rt = rhs_type.dyn_cast(); + if (!lhs_rt || !rhs_rt) return true; + SmallVector shape; + return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(), + rhs_rt.getShape(), shape); + } + return true; +} + // Given two types `a` and `b`, returns a refined type which is cast compatible // with both `a` and `b` and is equal to or more precise than both of them. It // returns empty Type if the input types are not cast compatible. @@ -100,7 +287,7 @@ mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b, if (a == b) return a; } } - if (a.getKind() != b.getKind()) return nullptr; + if (a.getTypeID() != b.getTypeID()) return nullptr; // If either is not a type that contain subtypes then the types are not cast // compatible. @@ -156,199 +343,6 @@ mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b, return mlir::RankedTensorType::get(refined_shape, refined_element_ty); } -} // namespace - -namespace mlir { -namespace TF { -//===----------------------------------------------------------------------===// -// Utility iterators -//===----------------------------------------------------------------------===// - -OperandShapeIterator::OperandShapeIterator(Operation::operand_iterator it) - : llvm::mapped_iterator> (*)(Value)>( - it, &GetShape) {} - -ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it) - : llvm::mapped_iterator> (*)(Value)>( - it, &GetShape) {} - -//===----------------------------------------------------------------------===// -// TF types helper functions -//===----------------------------------------------------------------------===// - -TensorFlowType TensorFlowRefType::get(Type type) { - MLIRContext* ctx = type.getContext(); - switch (getElementTypeOrSelf(type).getKind()) { - case StandardTypes::F16: - return HalfRefType::get(ctx); - case StandardTypes::F32: - return FloatRefType::get(ctx); - case StandardTypes::F64: - return DoubleRefType::get(ctx); - case StandardTypes::BF16: - return Bfloat16RefType::get(ctx); - case StandardTypes::Complex: { - const auto& etype = type.cast().getElementType(); - switch (getElementTypeOrSelf(etype).getKind()) { - case StandardTypes::F32: - return Complex64RefType::get(ctx); - case StandardTypes::F64: - return Complex128RefType::get(ctx); - default: - llvm_unreachable("unexpected complex type"); - } - } - case StandardTypes::Integer: { - const auto& itype = type.cast(); - switch (itype.getWidth()) { - case 1: - return BoolRefType::get(ctx); - case 8: - return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx)) - : Int8RefType::get(ctx); - case 16: - return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx)) - : Int16RefType::get(ctx); - case 32: - return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx)) - : Int32RefType::get(ctx); - case 64: - return itype.isUnsigned() ? 
TensorFlowType(Uint64RefType::get(ctx)) - : Int64RefType::get(ctx); - default: - llvm_unreachable("unexpected integer type"); - } - } -#define HANDLE_TF_TYPE(tftype, enumerant, name) \ - case TensorFlowTypes::enumerant: \ - return tftype##RefType::get(ctx); - -#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) -// NOLINTNEXTLINE -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" - default: - llvm_unreachable("unexpected type kind"); - } -} - -Type TensorFlowRefType::RemoveRef() { - MLIRContext* ctx = getContext(); - switch (getKind()) { - case TensorFlowTypes::HALF_REF: - return mlir::FloatType::getF16(ctx); - case TensorFlowTypes::FLOAT_REF: - return mlir::FloatType::getF32(ctx); - case TensorFlowTypes::DOUBLE_REF: - return mlir::FloatType::getF64(ctx); - case TensorFlowTypes::BFLOAT16_REF: - return mlir::FloatType::getBF16(ctx); - case TensorFlowTypes::BOOL_REF: - return mlir::IntegerType::get(1, ctx); - case TensorFlowTypes::INT8_REF: - return mlir::IntegerType::get(8, ctx); - case TensorFlowTypes::INT16_REF: - return mlir::IntegerType::get(16, ctx); - case TensorFlowTypes::INT32_REF: - return mlir::IntegerType::get(32, ctx); - case TensorFlowTypes::INT64_REF: - return mlir::IntegerType::get(64, ctx); - case TensorFlowTypes::UINT8_REF: - return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx); - case TensorFlowTypes::UINT16_REF: - return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx); - case TensorFlowTypes::UINT32_REF: - return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx); - case TensorFlowTypes::UINT64_REF: - return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx); - case TensorFlowTypes::COMPLEX64_REF: - return mlir::ComplexType::get(mlir::FloatType::getF32(ctx)); - case TensorFlowTypes::COMPLEX128_REF: - return mlir::ComplexType::get(mlir::FloatType::getF64(ctx)); -#define HANDLE_TF_TYPE(tftype, enumerant, name) \ - case TensorFlowTypes::enumerant##_REF: \ - return tftype##Type::get(ctx); - -#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) -// NOLINTNEXTLINE -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" - default: - llvm_unreachable("unexpected tensorflow ref type kind"); - } -} - -Type TensorFlowTypeWithSubtype::RemoveSubtypes() { - MLIRContext* ctx = getContext(); - switch (getKind()) { - case TensorFlowTypes::VARIANT: - return VariantType::get(ctx); - case TensorFlowTypes::RESOURCE: - return ResourceType::get(ctx); - default: - llvm_unreachable("unexpected tensorflow type with subtypes kind"); - } -} - -ArrayRef TensorFlowTypeWithSubtype::GetSubtypes() { - switch (getKind()) { - case TensorFlowTypes::VARIANT: - return this->cast().getSubtypes(); - case TensorFlowTypes::RESOURCE: - return this->cast().getSubtypes(); - default: - llvm_unreachable("unexpected tensorflow type with subtypes kind"); - } -} - -// TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have -// similar structure that could be extracted into helper method. -bool BroadcastCompatible(ArrayRef lhs, ArrayRef rhs) { - if (lhs.size() != rhs.size()) return false; - for (auto types : llvm::zip(lhs, rhs)) { - auto lhs_type = std::get<0>(types); - auto rhs_type = std::get<1>(types); - - // This should be true for all TF ops: - auto lhs_tt = lhs_type.dyn_cast(); - auto rhs_tt = rhs_type.dyn_cast(); - if (!lhs_tt || !rhs_tt) { - if (lhs_type != rhs_type) return false; - continue; - } - - // Verify matching element types. 
These should be identical, except for - // variant type where unknown subtype is considered compatible with all - // subtypes. - auto lhs_et = lhs_tt.getElementType(); - auto rhs_et = rhs_tt.getElementType(); - if (lhs_et != rhs_et) { - // If either does not have subtypes, then the element types don't match. - auto lhs_wst = lhs_et.dyn_cast(); - auto rhs_wst = rhs_et.dyn_cast(); - if (!lhs_wst || !rhs_wst) return false; - - // Consider the subtype of variant types. - auto lhs_wst_st = lhs_wst.GetSubtypes(); - auto rhs_wst_st = rhs_wst.GetSubtypes(); - if (!lhs_wst_st.empty() && !rhs_wst_st.empty()) { - for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) { - if (!BroadcastCompatible(std::get<0>(subtypes), - std::get<1>(subtypes))) - return false; - } - } - } - - auto lhs_rt = lhs_type.dyn_cast(); - auto rhs_rt = rhs_type.dyn_cast(); - if (!lhs_rt || !rhs_rt) return true; - SmallVector shape; - return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(), - rhs_rt.getShape(), shape); - } - return true; -} bool HasCompatibleElementTypes(Type lhs, Type rhs, bool may_ignore_ref_type_lhs) { @@ -366,27 +360,31 @@ bool AreCastCompatible(ArrayRef types) { return true; } -ShapedType DropTypeSubTypes(ShapedType ty) { - Type element_ty = ty.getElementType(); - auto subtype_ty = element_ty.dyn_cast(); - if (!subtype_ty) return ty; +// Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default +// type for a composed type (such as a ref type or a type with subtypes). +template +Type DropTypeHelper(Type ty) { + Type element_ty = getElementTypeOrSelf(ty); + auto composed_type = element_ty.dyn_cast(); + if (!composed_type) return ty; - Type default_ty = GetDefaultTypeOf(subtype_ty); - if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty); - - return UnrankedTensorType::get(default_ty); + Type default_ty = GetDefaultTypeOf(composed_type); + if (auto ranked_ty = ty.dyn_cast()) { + return RankedTensorType::get(ranked_ty.getShape(), default_ty); + } else if (ty.dyn_cast()) { + return UnrankedTensorType::get(default_ty); + } else { + return default_ty; + } } -ShapedType DropRefType(ShapedType ty) { - Type element_ty = ty.getElementType(); - TF::TensorFlowRefType ref_ty = element_ty.dyn_cast(); - if (!ref_ty) return ty; - - Type default_ty = TF::GetDefaultTypeOf(ref_ty); - if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty); - - return UnrankedTensorType::get(default_ty); +Type DropSubTypes(Type ty) { + return DropTypeHelper(ty); } +Type DropRefType(Type ty) { return DropTypeHelper(ty); } + +Type DropRefAndSubTypes(Type ty) { return DropRefType(DropSubTypes(ty)); } + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index 125f6bb31df..60a86f32920 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -67,26 +67,13 @@ using ResultShapeRange = iterator_range; // TensorFlow types //===----------------------------------------------------------------------===// -namespace TensorFlowTypes { -// List of supported TensorFlowType kinds, necessary for isa/dyn_cast. -enum Kind { - FIRST_USED_TENSORFLOW_TYPE = Type::FIRST_TENSORFLOW_TYPE, -#define HANDLE_TF_TYPE(tftype, enumerant, name) enumerant, -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" - LAST_USED_TENSORFLOW_TYPE, -}; -} // namespace TensorFlowTypes - // The base class in the TensorFlow type hierarchy. 
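The DropTypeHelper-based functions above now accept bare element types as well as tensor types. A small usage sketch, with illustrative types, of what DropRefAndSubTypes is expected to produce under those definitions:

// Builds tensor<4x!tf.variant<tensor<2xf32>>>; dropping subtypes keeps the
// outer shape and yields tensor<4x!tf.variant>; no ref type is involved.
static Type ExampleDrop(MLIRContext *ctx) {
  Type f32 = FloatType::getF32(ctx);
  TensorType subtype = RankedTensorType::get({2}, f32);
  Type variant = TF::VariantType::get({subtype}, ctx);
  Type tensor_of_variant = RankedTensorType::get({4}, variant);
  return TF::DropRefAndSubTypes(tensor_of_variant);
}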
class TensorFlowType : public Type { public: using Type::Type; // Support method to enable LLVM-style type casting. - static bool classof(Type type) { - return type.getKind() >= Type::FIRST_TENSORFLOW_TYPE && - type.getKind() <= TensorFlowTypes::LAST_USED_TENSORFLOW_TYPE; - } + static bool classof(Type type); }; // Returns true if the specified type is a valid TensorFlow element type. @@ -105,10 +92,7 @@ static inline bool IsValidTFTensorType(Type type) { namespace detail { // Common implementation of TensorFlow types. The template argument indicates -// the concrete derived class per CRTP. Concrete classes must implement the -// following: -// - `static unsigned getTypeKind()` that returns the (fixed) kind of the -// type. +// the concrete derived class per CRTP. template class TensorFlowTypeImpl : public Type::TypeBase { @@ -116,14 +100,6 @@ class TensorFlowTypeImpl using Base = typename Type::TypeBase; using TFBase = TensorFlowTypeImpl; using Base::Base; - - // Get the unique'ed type in the given context. - static Derived get(MLIRContext* context) { - return Base::get(context, Derived::getTypeKind()); - } - - // Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { return kind == Derived::getTypeKind(); } }; } // namespace detail @@ -133,10 +109,7 @@ class TensorFlowRefType : public TensorFlowType { using TensorFlowType::TensorFlowType; // Checks if a type is TensorFlow Ref type. - static bool classof(Type type) { - return type.getKind() >= TensorFlowTypes::FLOAT_REF && - type.getKind() <= TensorFlowTypes::LAST_USED_TENSORFLOW_TYPE; - } + static bool classof(Type type); // Converts a type to the corresponding TensorFlowRef type. static TensorFlowType get(Type type); @@ -182,7 +155,6 @@ static inline Type GetElementTypeOrSelfResolveRef(Type type) { class tftype##Type : public detail::TensorFlowTypeImpl { \ public: \ using TFBase::TFBase; \ - static unsigned getTypeKind() { return TensorFlowTypes::enumerant; } \ }; // Custom TensorFlow types are defined separately. @@ -220,8 +192,6 @@ class TypeWithSubtypeStorage : public TypeStorage { // opaque and their interpretation depends on the actual underlying type. // The template argument indicates the concrete derived class per CRTP. Concrete // classes must implement the following: -// - `static unsigned getTypeKind()` that returns the (fixed) kind of the -// type. // - `static std::string getTypeName()` that returns the name of the type for // verification logging. template @@ -233,19 +203,16 @@ class TypeWithSubtypeImpl using Base::Base; static Derived get(ArrayRef subtypes, MLIRContext* context) { - return Base::get(context, Derived::getTypeKind(), subtypes); + return Base::get(context, subtypes); } static Derived getChecked(ArrayRef subtypes, MLIRContext* context, Location loc) { - return Base::getChecked(loc, Derived::getTypeKind(), subtypes); + return Base::getChecked(loc, subtypes); } static Derived get(MLIRContext* context) { return get({}, context); } - // Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { return kind == Derived::getTypeKind(); } - static LogicalResult verifyConstructionInvariants( Location loc, ArrayRef subtypes) { // Each of the subtypes should be a valid TensorFlow type. @@ -269,10 +236,7 @@ class TensorFlowTypeWithSubtype : public TensorFlowType { using TensorFlowType::TensorFlowType; // Checks if a type is TensorFlow type with subtypes. 
- static bool classof(Type type) { - return type.getKind() == TensorFlowTypes::VARIANT || - type.getKind() == TensorFlowTypes::RESOURCE; - } + static bool classof(Type type); // Converts a TypeWithSubtype type to the same type but without its subtypes. Type RemoveSubtypes(); @@ -294,7 +258,6 @@ static inline Type GetDefaultTypeOf(TensorFlowTypeWithSubtype type) { class ResourceType : public detail::TypeWithSubtypeImpl { public: using TFBase::TFBase; - static unsigned getTypeKind() { return TensorFlowTypes::RESOURCE; } static std::string getTypeName() { return "ResourceType"; } }; @@ -306,10 +269,18 @@ class ResourceType : public detail::TypeWithSubtypeImpl { class VariantType : public detail::TypeWithSubtypeImpl { public: using TFBase::TFBase; - static unsigned getTypeKind() { return TensorFlowTypes::VARIANT; } static std::string getTypeName() { return "VariantType"; } }; +// Given two types `a` and `b`, returns a refined type which is cast compatible +// with both `a` and `b` and is equal to or more precise than both of them. It +// returns empty Type if the input types are not cast compatible. +// Provides option to ignore ref types on 'a'. This is useful for TF ops that +// might allow operands to either be same as result type or be a ref type +// corresponding to it. +mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b, + bool may_ignore_ref_type_a); + // Returns whether two arrays of Type are broadcast compatible. bool BroadcastCompatible(ArrayRef lhs, ArrayRef rhs); @@ -331,15 +302,21 @@ bool HasCompatibleElementTypes(Type lhs, Type rhs, // compatible. bool AreCastCompatible(ArrayRef types); -// If the given tensor has elements of type with subtypes, then returns a new -// type after dropping subtypes info. Otherwise, returns the original type as -// is. -ShapedType DropTypeSubTypes(ShapedType ty); +// If `ty` is a tensor type and its element type has subtypes, then returns a +// new type of same shape but dropped subtypes for the element type. +// Otherwise, if `ty` has subtypes, then returns corresponding type with dropped +// subtypes. +// Otherwise, returns the original type `ty`. +Type DropSubTypes(Type ty); -// If the given tensor has elements of type ref, then returns a new type -// of the shape, but corresponding non-ref type as element type. Otherwise, -// returns the original type as is. -ShapedType DropRefType(ShapedType ty); +// If `ty` is a tensor type and has elements of a ref type, then returns a new +// type of same shape but corresponding non-ref type as element type. +// Otherwise, if `ty` is a ref type, then returns corresponding non-ref type. +// Otherwise, returns the original type `ty`. +Type DropRefType(Type ty); + +// Convenience call for executing both `DropRefType` and `DropSubTypes`. +Type DropRefAndSubTypes(Type ty); } // end namespace TF } // end namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc new file mode 100644 index 00000000000..6a6a7574f29 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h" + +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc.inc" diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h similarity index 57% rename from tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc rename to tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h index 9d1c354690a..039f211533c 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h @@ -13,11 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TFRT_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TFRT_OPS_H_ -// Static initialization for *HLO dialects registration. -static mlir::DialectRegistration mhlo_ops; -static mlir::DialectRegistration chlo_ops; -static mlir::DialectRegistration lmhlo_ops; +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TFRT_OPS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td new file mode 100644 index 00000000000..fea9500b638 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td @@ -0,0 +1,61 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This is the operation definition file for TensorFlow operations with +// implementation available only in TFRT. + +#ifndef TFRT_OPS +#define TFRT_OPS + +include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" +include "mlir/IR/OpBase.td" + +def TF__JitFusedMatMulOp : TF_Op<"_JitFusedMatMul", [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = [{ + MatMul operation with an output fusion compiled at runtime via MLIR codegen. + }]; + + let description = [{ +The inputs to the MatMul are specified by `a` and `b`. The series of operations +that follows is specified by the `fusion` attribute, which is a list of output +kernel names specified as strings (e.g. "BiasAdd"). They are performed in order, +where the (first) input to each op is the output of the preceding op. The first +input and the output of each fused_op must be of type T. + +Supported list of fusions is defined by ContractionOutputKernelBuilder +implementations. + +*WARN*: This is a TFRT only operations, and it does not exist in TF. This +operation is only added by the ContractionFusion pass. + }]; + + let arguments = (ins + TensorOf<[F32]>:$a, + TensorOf<[F32]>:$b, + Variadic>:$additional_args, + + DefaultValuedAttr:$transpose_a, + DefaultValuedAttr:$transpose_b, + DefaultValuedAttr:$fusion + ); + + let results = (outs + TensorOf<[F32]>:$product + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +#endif // TFRT_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/ops/mlir_local_var_op.cc b/tensorflow/compiler/mlir/tensorflow/ops/mlir_local_var_op.cc index 211866900aa..d2c2cecdfdd 100644 --- a/tensorflow/compiler/mlir/tensorflow/ops/mlir_local_var_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/ops/mlir_local_var_op.cc @@ -21,7 +21,7 @@ namespace tensorflow { REGISTER_OP("MlirLocalVarOp") .Output("resource: resource") .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"(Creates a handle to a in-scope variable. + .Doc(R"(Creates a handle to an in-scope variable. Used by internal passes for temporary representation of local state, which will be eventually removed.)"); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir index 05d34eb0755..6654341ab42 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir @@ -285,7 +285,7 @@ func @empty_island_multiple_data_results(%arg0: tensor<*xf32>, %arg1: tensor<*xi // and certain tf_executor ops are added correctly. 
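The _JitFusedMatMul description above is easiest to read next to a concrete rewrite. The sketch below mirrors the contraction_fusion.mlir test added later in this patch (the output of the -tf-contraction-fusion pass); the SSA names and the 8x32/32x64 f32 shapes are illustrative only.

  // Before fusion:
  %0 = "tf.MatMul"(%lhs, %rhs) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<8x64xf32>
  %1 = "tf.BiasAdd"(%0, %bias) {data_format = "NHWC"} : (tensor<8x64xf32>, tensor<64xf32>) -> tensor<8x64xf32>
  %2 = "tf.Relu"(%1) : (tensor<8x64xf32>) -> tensor<8x64xf32>

  // After fusion: a single op, with the output kernels recorded in application order.
  %fused = "tf._JitFusedMatMul"(%lhs, %rhs, %bias) {fusion = ["BiasAdd", "Relu"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<8x64xf32>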
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print" -// CHECK: tf_executor.NextIteration.Sink [{{.*}}] {{.*}}, %[[CONTROL]] +// CHECK: tf_executor.NextIteration.Sink[{{.*}}] {{.*}}, %[[CONTROL]] func @next_iteration_sink_control_input() { tf_executor.graph { %source:3 = tf_executor.NextIteration.Source : tensor<*xi32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 595bdce5be4..ff90c6f4c5b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -444,6 +444,14 @@ func @testReshapeNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x return %0 : tensor<2x4xf32> } +// CHECK-LABEL: func @testBroadcastToNoOp +func @testBroadcastToNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x4xf32> { + %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<2x4xf32>, tensor<2xi32>) -> tensor<2x4xf32> + + // CHECK: return %arg0 + return %0 : tensor<2x4xf32> +} + // CHECK-LABEL: func @testPackShapeComputation func @testPackShapeComputation(%arg0: tensor, %arg1: tensor, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) { // Test dimensions sizes. @@ -560,6 +568,14 @@ func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: return %0: tensor<*xf16> } +// CHECK-LABEL: testTileMultiplesAllOnes +func @testTileMultiplesAllOnes(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %cst = constant dense <[1, 1]> : tensor<2xi32> + // CHECK: return %arg0 + %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32> + return %0: tensor<2x3xf32> +} + // CHECK-LABEL: testLogicalNotOfEqual func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> { %0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> @@ -620,6 +636,15 @@ func @testLogicalNotOfLessEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32> // CHECK: return %0 } +// CHECK-LABEL: testSizeFolding +func @testSizeFolding(%arg0: tensor<3x5x7xf32>) -> tensor { + %0 = "tf.Size"(%arg0) : (tensor<3x5x7xf32>) -> tensor + return %0: tensor + +// CHECK: %0 = "tf.Const"() {value = dense<105> : tensor} : () -> tensor +// CHECK: return %0 : tensor +} + // CHECK-LABEL: testDivWithSqrtDivisor func @testDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Sqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> @@ -685,6 +710,15 @@ func @identityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x5x6xf32> { // CHECK: return %arg0 } +// CHECK-LABEL: @identityTransposeConst +func @identityTransposeConst(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x5x6xf32> { + %0 = constant dense<[0, 1, 2, 3, 4]> : tensor<5xi32> + %1 = "tf.Transpose"(%arg0, %0) : (tensor<2x3x4x5x6xf32>, tensor<5xi32>) -> tensor<2x3x4x5x6xf32> + + return %1 : tensor<2x3x4x5x6xf32> + // CHECK: return %arg0 +} + // CHECK-LABEL: @nonIdentityTranspose func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32> { %0 = "tf.Const"() {value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32> @@ -707,6 +741,17 @@ func @cancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> { // CHECK: return %arg0 } +// CHECK-LABEL: @cancellableTransposeConst +func @cancellableTransposeConst(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> { + %0 = constant 
dense<[0, 3, 1, 2]> : tensor<4xi32> + %1 = constant dense<[0, 2, 3, 1]> : tensor<4xi32> + %2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32> + %3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32> + + return %3 : tensor<1x4x4x8xf32> + // CHECK: return %arg0 +} + // CHECK-LABEL: @nonCancellableTranspose func @nonCancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<4x1x4x8xf32> { %0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> @@ -725,13 +770,72 @@ func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> { return %0 : tensor<*xf32> } -// CHECK-LABEL: func @ToBool_0DScalar -func @ToBool_0DScalar(%arg0: tensor) -> tensor { +// CHECK-LABEL: func @ToBool_0DScalarI1 +func @ToBool_0DScalarI1(%arg0: tensor) -> tensor { // CHECK: return %arg0 %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @ToBool_0DScalarInt +func @ToBool_0DScalarInt(%arg0: tensor) -> tensor { + // CHECK: [[Zero:%.*]] = "tf.Const"() {value = dense<0> : tensor} + // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]]) + // CHECK: return [[NE]] + %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_0DScalarFloat +func @ToBool_0DScalarFloat(%arg0: tensor) -> tensor { + // CHECK: [[Zero:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} : () -> tensor + // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]]) + // CHECK: return [[NE]] + %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_0DScalarString +func @ToBool_0DScalarString(%arg0: tensor) -> tensor { + // CHECK: [[EmptyStr:%.*]] = "tf.Const"() {value = dense<""> : tensor} : () -> tensor + // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[EmptyStr]]) {incompatible_shape_error = false} : (tensor, tensor) -> tensor + // CHECK: return [[NE]] : tensor + %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_1DTensor +func @ToBool_1DTensor(%arg0: tensor<1xf32>) -> tensor { + // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense : tensor} : () -> tensor + // CHECK: return [[Const]] + %0 = "tf.ToBool"(%arg0) : (tensor<1xf32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_1DTensorZeroDim +func @ToBool_1DTensorZeroDim(%arg0: tensor<0xf32>) -> tensor { + // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense : tensor} : () -> tensor + // CHECK: return [[Const]] + %0 = "tf.ToBool"(%arg0) : (tensor<0xf32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_2DTensor +func @ToBool_2DTensor(%arg0: tensor<1x5xf32>) -> tensor { + // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense : tensor} : () -> tensor + // CHECK: return [[Const]] + %0 = "tf.ToBool"(%arg0) : (tensor<1x5xf32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_2DTensorZeroDim +func @ToBool_2DTensorZeroDim(%arg0: tensor<1x0xf32>) -> tensor { + // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense : tensor} : () -> tensor + // CHECK: return [[Const]] + %0 = "tf.ToBool"(%arg0) : (tensor<1x0xf32>) -> tensor + return %0 : tensor +} + // CHECK-LABEL: testReadVariableOpOfCast func @testReadVariableOpOfCast(%arg0: tensor>>) -> tensor<8x40xf32> { %0 = "tf.Cast"(%arg0) : (tensor>>) -> tensor<*x!tf.resource> @@ -826,6 +930,51 @@ func @foldIf(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tens return %4 : tensor } +// CHECK-LABEL: foldIfRegion +func @foldIfRegion(%arg0: tensor, %arg1: tensor, 
%arg2: tensor) -> (tensor, tensor) { + %false = "tf.Const"() {value = dense : tensor} : () -> tensor + %true = "tf.Const"() {value = dense : tensor} : () -> tensor + + // CHECK: [[Val0:%.*]] = "tf.Mul"(%arg0, %arg1) + %0 = "tf.IfRegion"(%true) ({ + %true_value = "tf.Mul"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%true_value) : (tensor) -> () + }, { + %false_value = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%false_value) : (tensor) -> () + }) { is_stateless = true}: (tensor) -> tensor + + // CHECK: [[Val1:%.*]] = "tf.Sub"(%arg0, %arg1) + %1 = "tf.IfRegion"(%false) ({ + %true_value = "tf.Mul"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%true_value) : (tensor) -> () + }, { + %false_value = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%false_value) : (tensor) -> () + }) { is_stateless = true}: (tensor) -> tensor + + // CHECK: return [[Val0]], [[Val1]] + return %0, %1 : tensor, tensor +} + +// CHECK-LABEL: foldIfRegionMismatchedTypes +func @foldIfRegionMismatchedTypes(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<1xf32> { + %false = "tf.Const"() {value = dense : tensor} : () -> tensor + %true = "tf.Const"() {value = dense : tensor} : () -> tensor + + // CHECK: [[Val0:%.*]] = "tf.Mul"(%arg0, %arg1) + // CHECK-NEXT: [[Cast:%.*]] = "tf.Cast"([[Val0]]) + // CHECK-NEXT: return [[Cast]] + %0 = "tf.IfRegion"(%true) ({ + %true_value = "tf.Mul"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%true_value) : (tensor) -> () + }, { + %false_value = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%false_value) : (tensor) -> () + }) { is_stateless = true}: (tensor) -> tensor<1xf32> + return %0 : tensor<1xf32> +} + // CHECK-LABEL: foldCase func @foldCase(%arg0: tensor, %arg1: tensor) -> (tensor) { %2 = constant dense<1> : tensor @@ -834,11 +983,11 @@ func @foldCase(%arg0: tensor, %arg1: tensor) -> (tensor) { // CHECK: PartitionedCall // CHECK-SAME: device = "noodle" // CHECK-SAME: f = @add - %4 = "tf.Case"(%2, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], device = "noodle"} : (tensor, tensor, tensor) -> tensor + %4 = "tf.Case"(%2, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], device = "noodle", is_stateless = false} : (tensor, tensor, tensor) -> tensor // CHECK: PartitionedCall // CHECK-SAME: _cluster_launch = "not_ready" // CHECK-SAME: f = @sub - %5 = "tf.Case"(%3, %4, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], _cluster_launch = "not_ready"} : (tensor, tensor, tensor) -> tensor + %5 = "tf.Case"(%3, %4, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], _cluster_launch = "not_ready", is_stateless = false} : (tensor, tensor, tensor) -> tensor return %5 : tensor } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/BUILD new file mode 100644 index 00000000000..6be08ac988c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/BUILD @@ -0,0 +1,25 @@ +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +licenses(["notice"]) + +glob_lit_tests( + data = [ + ":test_utilities", + ], + driver = "@llvm-project//mlir:run_lit.sh", + test_file_exts = [ + "mlir", + "pbtxt", + ], +) + +# Bundle together all of the test utilities that are used by tests. 
+filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/mlir:tf-mlir-translate", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/add.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/add.mlir new file mode 100644 index 00000000000..84e3f528a5c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/add.mlir @@ -0,0 +1,38 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-return-tuple | FileCheck %s +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-use-tuple-args -emit-return-tuple | FileCheck -check-prefix=TUPLE-ARGS %s + +module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor + } +} + +// CHECK-LABEL: HloModule main +// CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> (f32[]) { +// CHECK-NEXT: %[[ARG0]] = f32[] parameter(0) +// CHECK-NEXT: %[[ARG1]] = f32[] parameter(1) +// CHECK-NEXT: [[ADD:%.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]]) +// CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (f32[]) tuple(f32[] [[ADD]]) +// CHECK-NEXT: } + +// CHECK: // InputMapping {0, 1} +// CHECK-NEXT: // XlaInputShape f32[] +// CHECK-NEXT: // XlaInputShape f32[] +// CHECK-NEXT: // XlaOutputShape (f32[]) +// CHECK-NEXT: // XlaOutputDescription type=float shape=() + + +// TUPLE-ARGS-LABEL: HloModule main +// TUPLE-ARGS: ENTRY %main.{{[0-9]+}} ([[ARG_TUPLE:.*]]: (f32[], f32[])) -> (f32[]) { +// TUPLE-ARGS: %[[ARG_TUPLE]] = (f32[], f32[]) parameter(0) +// TUPLE-ARGS: [[ARG0:%.*]] = f32[] get-tuple-element((f32[], f32[]) %[[ARG_TUPLE]]), index=0 +// TUPLE-ARGS: [[ARG1:%.*]] = f32[] get-tuple-element((f32[], f32[]) %[[ARG_TUPLE]]), index=1 +// TUPLE-ARGS: [[ADD:%.*]] = f32[] add(f32[] [[ARG0]], f32[] [[ARG1]]) +// TUPLE-ARGS: ROOT %tuple.{{[0-9]+}} = (f32[]) tuple(f32[] [[ADD]]) +// TUPLE-ARGS: } + +// TUPLE-ARGS: // InputMapping {0, 1} +// TUPLE-ARGS-NEXT: // XlaInputShape (f32[], f32[]) +// TUPLE-ARGS-NEXT: // XlaOutputShape (f32[]) +// TUPLE-ARGS-NEXT: // XlaOutputDescription type=float shape=() diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/argument-sharding-invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/argument-sharding-invalid.mlir new file mode 100644 index 00000000000..5347037d7cf --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/argument-sharding-invalid.mlir @@ -0,0 +1,9 @@ +// RUN: not tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=128,10 -emit-use-tuple-args -emit-return-tuple 2>&1 | FileCheck %s + +module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<128x8xf32> {mhlo.sharding = "bad_sharding"}) { + return + } +} + +// CHECK: failed to parse argument sharding 0 'bad_sharding' diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/argument-sharding.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/argument-sharding.mlir new file mode 100644 index 00000000000..7154919c3d1 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/argument-sharding.mlir @@ -0,0 +1,38 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=128,10:10,1024:128,1024 -emit-use-tuple-args -emit-return-tuple | FileCheck %s + +module attributes 
{tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {mhlo.sharding = ""}) { + return + } +} + +// The following xla::OpSharding protos are used: +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" +// Proto debug string: +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// +// Serialized string: +// "\08\01\1A\01\01\22\01\00" +// Proto debug string: +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// +// Serialized string: +// "" +// Proto debug string (empty but would equivalent to): +// type: REPLICATED + +// CHECK-LABEL: HloModule main +// CHECK: ENTRY %main.{{[0-9]+}} ([[ARG_TUPLE:.*]]: (f32[128,10], f32[10,1024], f32[128,1024])) -> () { +// CHECK: %[[ARG_TUPLE]] = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0) +// CHECK-SAME: sharding={ +// CHECK-SAME: {devices=[1,2]0,1} +// CHECK-SAME: {maximal device=0} +// CHECK-SAME: {replicated} +// CHECK-SAME: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding-hook.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding-hook.mlir new file mode 100644 index 00000000000..c745fbc0744 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding-hook.mlir @@ -0,0 +1,16 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-use-tuple-args -emit-return-tuple | FileCheck %s + +module attributes {tf.versions = {producer = 179 : i32}} { + func @main() -> (tensor<0xi32>, tensor<0xi32>) { + %0 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %r0, %r1 = "tf.BroadcastGradientArgs"(%0, %0) {T = i32} : (tensor<0xi32>, tensor<0xi32>) -> (tensor<0xi32>, tensor<0xi32>) + return %r0, %r1 : tensor<0xi32>, tensor<0xi32> + } +} + +// CHECK-LABEL: HloModule main +// CHECK: ENTRY %main.{{[0-9+]}} ([[ARG_TUPLE:.*]]: ()) -> (s32[0], s32[0]) { +// CHECK: %[[ARG_TUPLE]] = () parameter(0) +// CHECK: [[CONSTANT:%.*]] = s32[0]{0} constant({}) +// CHECK: ROOT %tuple.{{[0-9]+}} = (s32[0]{0}, s32[0]{0}) tuple(s32[0]{0} [[CONSTANT]], s32[0]{0} [[CONSTANT]]) +// CHECK: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding.mlir new file mode 100644 index 00000000000..e54ff79e5e4 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding.mlir @@ -0,0 +1,23 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=10,19:19,10 -emit-use-tuple-args -emit-return-tuple | FileCheck %s + +module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {mhlo.is_same_data_across_replicas}) -> tensor<10x19xf32> { + %0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64> + %1 = "tf.Reshape"(%arg1, %0) : (tensor<19x10xf32>, tensor<2xi64>) -> tensor<10x19xf32> + return %1 : tensor<10x19xf32> + } +} + +// Tests that foldable ops are constant-folded to enable legalization of ops +// that require compile time constant operand. +// "tf.Shape" can only be folded away after shape inference. tf.Reshape can only +// be lowered when tf.Shape is folded into a constant. 
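To make the comment above concrete, the fold this test depends on looks roughly like the following in TF dialect IR once shape inference has refined %arg0 to tensor<10x19xf32> (a sketch, not part of the test):

  %shape = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64>
  // ...folds to a compile-time constant, which is what lets tf.Reshape legalize:
  %shape = "tf.Const"() {value = dense<[10, 19]> : tensor<2xi64>} : () -> tensor<2xi64>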
+ +// CHECK-LABEL: HloModule main +// CHECK: ENTRY %main.{{[0-9]+}} ([[ARG_TUPLE:.*]]: (f32[10,19], f32[19,10])) -> (f32[10,19]) { +// CHECK: %[[ARG_TUPLE]] = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0), parameter_replication={false,true} +// CHECK: [[ARG0:%.*]] = f32[10,19]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %[[ARG_TUPLE]]), index=0 +// CHECK: [[ARG1:%.*]] = f32[19,10]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %[[ARG_TUPLE]]), index=1 +// CHECK: [[RESHAPE:%.*]] = f32[10,19]{1,0} reshape(f32[19,10]{1,0} [[ARG1]]) +// CHECK: ROOT %tuple.{{[0-9]+}} = (f32[10,19]{1,0}) tuple(f32[10,19]{1,0} [[RESHAPE]]) +// CHECK: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.mlir new file mode 100644 index 00000000000..3d1a34b932d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.mlir @@ -0,0 +1,27 @@ +// RUN: tf-mlir-translate -mlir-tf-graph-to-hlo-text %s -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s + +module attributes {tf.versions = {producer = 511 : i32}} { + func @main(%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>) { + tf_executor.graph { + %control = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %arg0) : (tensor<*x!tf.resource>, tensor<*xf32>) -> () + tf_executor.fetch %control : !tf_executor.control + } + return + } +} + +// Tests a conversion from Graph (tf_executor dialect MLIR) to MLIR with +// resource arguments. + +// CHECK-LABEL: HloModule main.{{[0-9]+}}, input_output_alias={ {0}: (1, {}, may-alias) } +// CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:.*]]: f32[2], [[ARG1:.*]]: f32[2]) -> (f32[2]) { +// CHECK-NEXT: %[[ARG1]] = f32[2]{0} parameter(1) +// CHECK-NEXT: %[[ARG0]] = f32[2]{0} parameter(0) +// CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (f32[2]{0}) tuple(f32[2]{0} %[[ARG0]]) +// CHECK-NEXT: } + +// CHECK: // InputMapping {0, 1} +// CHECK-NEXT: // XlaInputShape f32[2] +// CHECK-NEXT: // XlaInputShape f32[2] +// CHECK-NEXT: // XlaOutputShape (f32[2]) +// CHECK-NEXT: // ResourceUpdate input_index=1 type=float shape=(2) modified diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt new file mode 100644 index 00000000000..5fb90b1bce0 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt @@ -0,0 +1,66 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s + +node { + name: "arg0" + op: "_Arg" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +node { + name: "arg1" + op: "_Arg" + attr { + key: "T" + value { + type: DT_RESOURCE + } + } + attr { + key: "index" + value { + i: 1 + } + } +} +node { + name: "assign_variable" + op: "AssignVariableOp" + input: "arg1" + input: "arg0" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } +} +library { +} +versions { + producer: 511 +} + +# Tests a conversion from Graph to MLIR with resource arguments. 
+ +# CHECK-LABEL: HloModule main.{{[0-9]+}}, input_output_alias={ {0}: (1, {}, may-alias) } +# CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:.*]]: f32[2], [[ARG1:.*]]: f32[2]) -> (f32[2]) { +# CHECK-NEXT: %[[ARG1]] = f32[2]{0} parameter(1) +# CHECK-NEXT: %[[ARG0]] = f32[2]{0} parameter(0) +# CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (f32[2]{0}) tuple(f32[2]{0} %[[ARG0]]) +# CHECK-NEXT: } + +# CHECK: // InputMapping {0, 1} +# CHECK-NEXT: // XlaInputShape f32[2] +# CHECK-NEXT: // XlaInputShape f32[2] +# CHECK-NEXT: // XlaOutputShape (f32[2]) +# CHECK-NEXT: // ResourceUpdate input_index=1 type=float shape=(2) modified diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph.pbtxt new file mode 100644 index 00000000000..f1f7c6434eb --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph.pbtxt @@ -0,0 +1,47 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes='' -tf-input-data-types=DT_FLOAT -emit-return-tuple | FileCheck %s + +node { + name: "arg" + op: "_Arg" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +node { + name: "retval" + op: "_Retval" + input: "arg" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +versions { + producer: 511 +} + +# Verify that conversion from Graph to MLIR and empty shape representation +# function is successful. + +# CHECK-LABEL: HloModule main +# CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:.*]]: f32[]) -> (f32[]) { +# CHECK-NEXT: %[[ARG0]] = f32[] parameter(0) +# CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (f32[]) tuple(f32[] %[[ARG0]]) +# CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/mlir-module-serialized-str-attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/mlir-module-serialized-str-attr.mlir new file mode 100644 index 00000000000..b68f177b183 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/mlir-module-serialized-str-attr.mlir @@ -0,0 +1,10 @@ +// RUN: tf-mlir-translate -mlir-tf-mlir-to-str-attr %s | FileCheck %s + +module attributes {tf.versions = {producer = 888 : i32}} { + func @main(%arg0: tensor) -> tensor { + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor loc(unknown) + return %0 : tensor loc(unknown) + } loc(unknown) +} loc(unknown) + +// CHECK: "\0A\0Amodule attributes {tf.versions = {producer = 888 : i32}} {\0A func @main(%arg0: tensor) -> tensor {\0A %0 = \22tf.Identity\22(%arg0) : (tensor) -> tensor loc(unknown)\0A return %0 : tensor loc(unknown)\0A } loc(unknown)\0A} loc(unknown)" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/result-sharding.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/result-sharding.mlir new file mode 100644 index 00000000000..c9c02ba2588 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/result-sharding.mlir @@ -0,0 +1,39 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=128,10:10,1024:128,1024 -emit-use-tuple-args -emit-return-tuple | FileCheck %s + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 351 : i32}} { + func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {mhlo.sharding = 
"\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {mhlo.sharding = ""}) { + return %arg0, %arg1, %arg2 : tensor<128x10xf32>, tensor<10x1024xf32>, tensor<128x1024xf32> + } +} + +// The following xla::OpSharding protos are used: +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" +// Proto debug string: +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// +// Serialized string: +// "\08\01\1A\01\01\22\01\00" +// Proto debug string: +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// +// Serialized string: +// "" +// Proto debug string (empty but would equivalent to): +// type: REPLICATED + +// CHECK-LABEL: HloModule main +// CHECK: ENTRY %main.{{[0-9]+}} +// CHECK-SAME: (arg_tuple.{{[0-9]+}}: (f32[128,10], f32[10,1024], f32[128,1024])) -> (f32[128,10], f32[10,1024], f32[128,1024]) { +// CHECK: ROOT %tuple.{{[0-9]+}} +// CHECK-SAME: sharding={ +// CHECK-SAME: {devices=[1,2]0,1} +// CHECK-SAME: {maximal device=0} +// CHECK-SAME: {replicated} +// CHECK-SAME: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/serialized-mlir-module-str-attr-invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/serialized-mlir-module-str-attr-invalid.mlir new file mode 100644 index 00000000000..ced11f3a083 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/serialized-mlir-module-str-attr-invalid.mlir @@ -0,0 +1,5 @@ +// RUN: not tf-mlir-translate -mlir-tf-str-attr-to-mlir %s 2>&1 | FileCheck %s + +"totally @invalid MLIR module {here} <-" + +// CHECK: Invalid argument: could not parse MLIR module-:1:1: error: custom op 'totally' is unknown diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/serialized-mlir-module-str-attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/serialized-mlir-module-str-attr.mlir new file mode 100644 index 00000000000..9a0e1dc38c8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/serialized-mlir-module-str-attr.mlir @@ -0,0 +1,15 @@ +// RUN: tf-mlir-translate -mlir-tf-str-attr-to-mlir %s -mlir-print-debuginfo | FileCheck %s + +"\0A\0Amodule attributes {tf.versions = {producer = 888 : i32}} {\0A func @main(%arg0: tensor) -> tensor {\0A %0 = \22tf.Identity\22(%arg0) : (tensor) -> tensor loc(unknown)\0A return %0 : tensor loc(unknown)\0A } loc(unknown)\0A} loc(unknown)" + +// Test simple serialized computation consisting of a function named `main` +// with a tf.Identity op forwarding the function single argument to the function +// single result. 
+ +// CHECK-LABEL: module +// CHECK-SAME: attributes {tf.versions = {producer = 888 : i32}} { +// CHECK-NEXT: func @main([[ARG0:%.+]]: tensor) -> tensor { +// CHECK-NEXT: [[IDENTITY:%.+]] = "tf.Identity"([[ARG0]]) : (tensor) -> tensor loc(unknown) +// CHECK-NEXT: return [[IDENTITY]] : tensor loc(unknown) +// CHECK-NEXT: } loc(unknown) +// CHECK-NEXT: } loc(unknown) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference-after-legalization.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference-after-legalization.mlir new file mode 100644 index 00000000000..55bdea5dd36 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference-after-legalization.mlir @@ -0,0 +1,11 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=8,16,16,64:64 -emit-use-tuple-args -emit-return-tuple | FileCheck %s + +module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<8x16x16x64xbf16>, %arg1: tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) { + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) + return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32> + } +} + +// CHECK-LABEL: HloModule main +// CHECK: -> (bf16[8,16,16,64], f32[64], f32[64], f32[64], f32[64], f32[0]) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference.mlir new file mode 100644 index 00000000000..f9eca514da3 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/shape-inference.mlir @@ -0,0 +1,11 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=10,17:17,19 -emit-use-tuple-args -emit-return-tuple | FileCheck %s + +module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor { + %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor) -> tensor + return %0 : tensor + } +} + +// CHECK-LABEL: HloModule main +// CHECK: (arg_tuple.{{[0-9]+}}: (f32[10,17], f32[17,19])) -> (f32[10,19]) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index b86815dbe57..fff985efa6f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -89,16 +89,45 @@ func @testEmptybf16() -> (tensor<5xbf16>) { } // CHECK-LABEL: func @testShapeN -func @testShapeN(%arg0: tensor, %arg1: tensor<1x32x32x16xf32>, %arg2: tensor<*xf32>) -> (tensor<0xi64>, tensor<4xi64>, tensor<4xi64>, tensor) { +func @testShapeN(%arg0: tensor, %arg1: tensor<1x32x32x16xf32>) -> (tensor<0xi64>, tensor<4xi64>) { - // CHECK: "tf.Const"() {value = dense<> : tensor<0xi64> - // CHECK: "tf.Const"() {value = dense<[1, 32, 32, 16]> : tensor<4xi64>} + // CHECK: %[[SHAPE0:.*]] = "tf.Const"() {value = dense<> : 
tensor<0xi64>} + // CHECK: %[[SHAPE1:.*]] = "tf.Const"() {value = dense<[1, 32, 32, 16]> : tensor<4xi64>} %0:2 = "tf.ShapeN"(%arg0, %arg1) : (tensor, tensor<1x32x32x16xf32>) -> (tensor<0xi64>, tensor<4xi64>) - // CHECK: tf.ShapeN - %1:2 = "tf.ShapeN"(%arg1, %arg2) : (tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<4xi64>, tensor) + // CHECK: return %[[SHAPE0]], %[[SHAPE1]] + return %0#0, %0#1 : tensor<0xi64>, tensor<4xi64> +} - return %0#0, %0#1, %1#0, %1#1 : tensor<0xi64>, tensor<4xi64>, tensor<4xi64>, tensor +// CHECK-LABEL: func @testShapeNPartialStatic +func @testShapeNPartialStatic(%arg0: tensor, %arg1: tensor<2x?x3xf32>, %arg2: tensor<1x32x32x16xf32>, %arg3: tensor<*xf32>) -> (tensor<0xi64>, tensor<3xi64>, tensor<4xi64>, tensor) { + // CHECK: %[[SHAPE0:.*]] = "tf.Const"() {value = dense<> : tensor<0xi64>} + // CHECK: %[[SHAPE2:.*]] = "tf.Const"() {value = dense<[1, 32, 32, 16]> : tensor<4xi64>} + // CHECK: %[[SHAPE13:.*]]:2 = "tf.ShapeN"(%arg1, %arg3) : (tensor<2x?x3xf32>, tensor<*xf32>) -> (tensor<3xi64>, tensor) + %0:4 = "tf.ShapeN"(%arg0, %arg1, %arg2, %arg3) : (tensor, tensor<2x?x3xf32>, tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<0xi64>, tensor<3xi64>, tensor<4xi64>, tensor) + + // CHECK: return %[[SHAPE0]], %[[SHAPE13]]#0, %[[SHAPE2]], %[[SHAPE13]]#1 + return %0#0, %0#1, %0#2, %0#3 : tensor<0xi64>, tensor<3xi64>, tensor<4xi64>, tensor +} + +// CHECK-LABEL: func @testShapeNOneDynamic +func @testShapeNOneDynamic(%arg0: tensor, %arg1: tensor<1x32x32x16xf32>, %arg2: tensor<*xf32>) -> (tensor<0xi64>, tensor<4xi64>, tensor) { + // CHECK: %[[SHAPE0:.*]] = "tf.Const"() {value = dense<> : tensor<0xi64>} + // CHECK: %[[SHAPE1:.*]] = "tf.Const"() {value = dense<[1, 32, 32, 16]> : tensor<4xi64>} + // CHECK: %[[SHAPE2:.*]] = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor + %0:3 = "tf.ShapeN"(%arg0, %arg1, %arg2) : (tensor, tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<0xi64>, tensor<4xi64>, tensor) + + // CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]] + return %0#0, %0#1, %0#2 : tensor<0xi64>, tensor<4xi64>, tensor +} + +// CHECK-LABEL: func @testShapeNToShape +func @testShapeNToShape(%arg0: tensor<*xf32>) -> tensor { + // CHECK: %[[SHAPE0:.*]] = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor + %0:1 = "tf.ShapeN"(%arg0) : (tensor<*xf32>) -> tensor + + // CHECK: return %[[SHAPE0]] + return %0#0 : tensor } // CHECK-LABEL: func @testLeakyRelu @@ -463,3 +492,13 @@ func @DontFoldTile() -> (tensor<8x10000xi32>) { return %3 : tensor<8x10000xi32> } // LINT.ThenChange(../transforms/constant_fold.cc:folding-policy) + +func @fold_conv() -> tensor<1x520x520x1xf32> { + %0 = "tf.Const"() {value = dense<0.111111112> : tensor<3x3x1x1xf32>} : () -> tensor<3x3x1x1xf32> + %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<1x520x520x1xf32>} : () -> tensor<1x520x520x1xf32> + %2 = "tf.DepthwiseConv2dNative"(%1, %0) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x520x520x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x520x520x1xf32> + return %2 : tensor<1x520x520x1xf32> + + // CHECK: tf.Const + // CHECK-NOT: tf.DepthwiseConv2dNative +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/contraction_fusion.mlir b/tensorflow/compiler/mlir/tensorflow/tests/contraction_fusion.mlir new file mode 100644 index 00000000000..b12f50ad525 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/contraction_fusion.mlir @@ -0,0 +1,37 @@ +// RUN: tf-opt %s -tf-contraction-fusion | FileCheck %s + +// CHECK-LABEL: 
matmulBiasAdd +func @matmulBiasAdd(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> tensor<8x64xf32> { + // CHECK: %[[FUSED:.*]] = "tf._JitFusedMatMul"(%arg1, %arg2, %arg0) + // CHECK-SAME: fusion = ["BiasAdd"] + // CHECK-SAME: transpose_a = false, transpose_b = false + %3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<8x64xf32> + %4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<8x64xf32>, tensor<64xf32>) -> tensor<8x64xf32> + // CHECK: return %[[FUSED]] + return %4 : tensor<8x64xf32> +} + +// CHECK-LABEL: matmulBiasAddRelu +func @matmulBiasAddRelu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> tensor<8x64xf32> { + // CHECK: %[[FUSED:.*]] = "tf._JitFusedMatMul"(%arg1, %arg2, %arg0) + // CHECK-SAME: fusion = ["BiasAdd", "Relu"] + // CHECK-SAME: transpose_a = false, transpose_b = false + %3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<8x64xf32> + %4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<8x64xf32>, tensor<64xf32>) -> tensor<8x64xf32> + %5 = "tf.Relu"(%4) : (tensor<8x64xf32>) -> tensor<8x64xf32> + // CHECK: return %[[FUSED]] + return %5 : tensor<8x64xf32> +} + +// CHECK-LABEL: matmulBiasAddLeakyRelu +func @matmulBiasAddLeakyRelu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> tensor<8x64xf32> { + // CHECK: %[[FUSED:.*]] = "tf._JitFusedMatMul"(%arg1, %arg2, %arg0) + // CHECK-SAME: alpha = 2.000000e-01 : f32 + // CHECK-SAME: fusion = ["BiasAdd", "LeakyRelu"] + // CHECK-SAME: transpose_a = false, transpose_b = false + %3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<8x64xf32> + %4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<8x64xf32>, tensor<64xf32>) -> tensor<8x64xf32> + %5 = "tf.LeakyRelu"(%4) { alpha = 0.2 : f32 } : (tensor<8x64xf32>) -> tensor<8x64xf32> + // CHECK: return %[[FUSED]] + return %5 : tensor<8x64xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir index ff4dbf41221..e6a92a520f0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir @@ -101,7 +101,7 @@ func @decompose_resource_apply_momentum_non_nesterov(%arg0: tensor, %arg1: // CHECK: [[ACCUM:%.*]] = "tf.ReadVariableOp"([[ACCUM_HANDLE]]) // CHECK: [[ACCUM_MOMENTUM:%.*]] = "tf.Mul"([[ACCUM]], [[MOMENTUM]]) - // CHECK: [[ACCUM_NEW:%.*]] = "tf.Add"([[ACCUM_MOMENTUM]], [[GRAD]]) + // CHECK: [[ACCUM_NEW:%.*]] = "tf.AddV2"([[ACCUM_MOMENTUM]], [[GRAD]]) // CHECK: "tf.AssignVariableOp"([[ACCUM_HANDLE]], [[ACCUM_NEW]]) // CHECK: [[ACCUM_NEW_LR:%.*]] = "tf.Mul"([[ACCUM_NEW]], [[LR]]) // CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) @@ -127,12 +127,12 @@ func @decompose_resource_apply_momentum_nesterov(%arg0: tensor, %arg1: tens // CHECK: [[ACCUM:%.*]] = "tf.ReadVariableOp"([[ACCUM_HANDLE]]) // CHECK: [[ACCUM_MOMENTUM:%.*]] = "tf.Mul"([[ACCUM]], [[MOMENTUM]]) - // CHECK: [[ACCUM_NEW:%.*]] = "tf.Add"([[ACCUM_MOMENTUM]], [[GRAD]]) + // CHECK: [[ACCUM_NEW:%.*]] = "tf.AddV2"([[ACCUM_MOMENTUM]], [[GRAD]]) // CHECK: "tf.AssignVariableOp"([[ACCUM_HANDLE]], [[ACCUM_NEW]]) // CHECK: [[GRAD_LR:%.*]] = "tf.Mul"([[GRAD]], [[LR]]) // CHECK: [[MOMENTUM_LR:%.*]] = 
"tf.Mul"([[MOMENTUM]], [[LR]]) // CHECK: [[ACCUM_NEW_MOMENTUM_LR:%.*]] = "tf.Mul"([[ACCUM_NEW]], [[MOMENTUM_LR]]) - // CHECK: [[DELTA:%.*]] = "tf.Add"([[GRAD_LR]], [[ACCUM_NEW_MOMENTUM_LR]]) + // CHECK: [[DELTA:%.*]] = "tf.AddV2"([[GRAD_LR]], [[ACCUM_NEW_MOMENTUM_LR]]) // CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) // CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[DELTA]]) // CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[VAR_NEW]]) @@ -231,6 +231,31 @@ func @decompose_resource_apply_adagradv2(%arg0: tensor, %arg1: tensor, return } +// ----- +// CHECK-LABEL: func @decompose_resource_apply_adagrad +// CHECK-SAME: (%[[LR:.*]]: tensor, %[[GRAD:.*]]: tensor) +func @decompose_resource_apply_adagrad(%arg0: tensor, %arg1: tensor) -> () { + + // CHECK: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + // CHECK: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + // CHECK: %[[GRAD_SQUARE:.*]] = "tf.Mul"(%[[GRAD]], %[[GRAD]]) : (tensor, tensor) -> tensor + // CHECK: %[[ACCUM:.*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32> + // CHECK: %[[ACCUM_NEW:.*]] = "tf.AddV2"(%[[ACCUM]], %[[GRAD_SQUARE]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> + // CHECK: %[[LR_MULTIPLY:.*]] = "tf.Mul"(%[[LR]], %[[GRAD]]) : (tensor, tensor) -> tensor + // CHECK: %[[SQRT:.*]] = "tf.Sqrt"(%[[ACCUM_NEW]]) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: %[[DIV:.*]] = "tf.Div"(%[[LR_MULTIPLY]], %[[SQRT]]) : (tensor, tensor<*xf32>) -> tensor<*xf32> + // CHECK: %[[VAR:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32> + // CHECK: %[[VAR_NEW:.*]] = "tf.Sub"(%[[VAR]], %[[DIV]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + // CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> () + // CHECK: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[ACCUM_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> () + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + + "tf.ResourceApplyAdagrad"(%0, %1, %arg0, %arg1) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor, tensor) -> () + + return +} + // ----- // Tests that composite tf.ResourceApplyAdam (non-Nesterov) operation is @@ -388,14 +413,14 @@ func @decompose_resource_apply_centered_RMS_prop(%arg0: tensor, %arg1: tens // CHECK: [[GRAD_SUB:%.*]] = "tf.Mul"([[GRADSQ]], [[SB]]) // CHECK: [[MS:%.*]] = "tf.ReadVariableOp"([[MS_HANDLE]]) // CHECK: [[MS_RHO:%.*]] = "tf.Mul"([[MS]], [[RHO]]) - // CHECK: [[MS_NEW:%.*]] = "tf.Add"([[GRAD_SUB]], [[MS_RHO]]) + // CHECK: [[MS_NEW:%.*]] = "tf.AddV2"([[GRAD_SUB]], [[MS_RHO]]) // CHECK: "tf.AssignVariableOp"([[MS_HANDLE]], [[MS_NEW]]) // CHECK: [[SUB_RHO:%.*]] = "tf.Sub"([[ONE]], [[RHO]]) // CHECK: [[SUB_GRAD:%.*]] = "tf.Mul"([[GRAD]], [[SUB_RHO]]) // CHECK: [[MG:%.*]] = "tf.ReadVariableOp"([[MG_HANDLE]]) // CHECK: [[MG_RHO:%.*]] = "tf.Mul"([[MG]], [[RHO]]) - // CHECK: [[MG_NEW:%.*]] = "tf.Add"([[SUB_GRAD]], [[MG_RHO]]) + // CHECK: [[MG_NEW:%.*]] = "tf.AddV2"([[SUB_GRAD]], [[MG_RHO]]) // CHECK: "tf.AssignVariableOp"([[MG_HANDLE]], [[MG_NEW]]) // CHECK: [[MOM:%.*]] = "tf.ReadVariableOp"([[MOM_HANDLE]]) @@ -403,11 +428,11 @@ func @decompose_resource_apply_centered_RMS_prop(%arg0: tensor, %arg1: tens // CHECK: 
[[LR_GRAD:%.*]] = "tf.Mul"([[LR]], [[GRAD]]) // CHECK: [[MG_MG:%.*]] = "tf.Mul"([[MG_NEW]], [[MG_NEW]]) - // CHECK: [[MG_NEW:%.*]] = "tf.Add"([[MG_MG]], [[EPSILON]]) + // CHECK: [[MG_NEW:%.*]] = "tf.AddV2"([[MG_MG]], [[EPSILON]]) // CHECK: [[MG_SUB:%.*]] = "tf.Sub"([[MS_NEW]], [[MG_NEW]]) // CHECK: [[MG_SQRT:%.*]] = "tf.Sqrt"([[MG_SUB]]) // CHECK: [[MOM_DIV:%.*]] = "tf.Div"([[LR_GRAD]], [[MG_SQRT]]) - // CHECK: [[MOM_NEW:%.*]] = "tf.Add"([[MOM_MOM]], [[MOM_DIV]]) + // CHECK: [[MOM_NEW:%.*]] = "tf.AddV2"([[MOM_MOM]], [[MOM_DIV]]) // CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) // CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[MOM_NEW]]) @@ -416,6 +441,33 @@ func @decompose_resource_apply_centered_RMS_prop(%arg0: tensor, %arg1: tens "tf.ResourceApplyCenteredRMSProp"(%0, %1, %2, %3, %arg4, %arg5, %arg6, %arg7, %arg8) {use_locking = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor, tensor, tensor, tensor, tensor) -> () return } +// ----- +// CHECK-LABEL: func @decompose_resource_apply_RMS_prop +// CHECK-SAME: (%[[VAR_HANDLE:.*]]: tensor<*x!tf.resource>, %[[MS_HANDLE:.*]]: tensor<*x!tf.resource>, %[[MOM_HANDLE:.*]]: tensor<*x!tf.resource>, +// CHECK-SAME: %[[LR:.*]]: tensor, %[[RHO:.*]]: tensor, %[[MOMENTUM:.*]]: tensor, %[[EPSILON:.*]]: tensor, %[[GRAD:.*]]: tensor) +func @decompose_resource_apply_RMS_prop(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>, %arg2: tensor<*x!tf.resource>, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor) -> () { +// CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor +// CHECK: %[[MS:.*]] = "tf.ReadVariableOp"(%[[MS_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32> +// CHECK: %[[MS_RHO:.*]] = "tf.Mul"(%[[MS]], %[[RHO]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK: %[[GRAD_SQUARE:.*]] = "tf.Square"(%[[GRAD]]) : (tensor) -> tensor +// CHECK: %[[ONE_RHO:.*]] = "tf.Sub"(%[[ONE]], %[[RHO]]) : (tensor, tensor) -> tensor +// CHECK: %[[MUL:.*]] = "tf.Mul"(%[[GRAD_SQUARE]], %[[ONE_RHO]]) : (tensor, tensor) -> tensor +// CHECK: %[[MS_NEW:.*]] = "tf.AddV2"(%[[MS_RHO]], %[[MUL]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK: "tf.AssignVariableOp"(%[[MS_HANDLE]], %[[MS_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> () +// CHECK: %[[MOM:.*]] = "tf.ReadVariableOp"(%[[MOM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32> +// CHECK: %[[MOMENTUM_MOM:.*]] = "tf.Mul"(%[[MOMENTUM]], %[[MOM]]) : (tensor, tensor<*xf32>) -> tensor<*xf32> +// CHECK: %[[LR_GRAD:.*]] = "tf.Mul"(%[[LR]], %[[GRAD]]) : (tensor, tensor) -> tensor +// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[MS_NEW]], %[[EPSILON]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK: %[[SQRT:.*]] = "tf.Sqrt"(%[[ADD]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: %[[DIV:.*]] = "tf.Div"(%[[LR_GRAD]], %[[SQRT]]) : (tensor, tensor<*xf32>) -> tensor<*xf32> +// CHECK: %[[MOM_NEW:.*]] = "tf.AddV2"(%[[MOMENTUM_MOM]], %[[DIV]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK: "tf.AssignVariableOp"(%[[MOM_HANDLE]], %[[MOM_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> () +// CHECK: %[[VAR:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32> +// CHECK: %[[VAR_NEW:.*]] = "tf.Sub"(%[[VAR]], %[[MOM_NEW]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> () + "tf.ResourceApplyRMSProp"(%arg0, %arg1, %arg2, %arg3, 
%arg4, %arg5, %arg6, %arg7) {use_locking = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor, tensor, tensor, tensor, tensor) -> () + return +} // ----- diff --git a/tensorflow/compiler/mlir/tensorflow/tests/device_copy.mlir b/tensorflow/compiler/mlir/tensorflow/tests/device_copy.mlir new file mode 100644 index 00000000000..8250bcf7101 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/device_copy.mlir @@ -0,0 +1,16 @@ +// RUN: tf-opt -tf-tensor-device-copy %s | FileCheck %s --dump-input=fail + +// CHECK-LABEL: func @fold_identity +// CHECK-SAME: ([[arg0:%.*]]: tensor<2x2xf32>, [[arg1:%.*]]: tensor<2x2xf32> +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32}} { + func @fold_identity(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = tf_executor.graph { + // CHECK: tf.MatMul + %outputs, %control = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK-NOT: tf.Identity + %outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = ""} : (tensor<2x2xf32>) -> tensor<2x2xf32> + tf_executor.fetch %outputs_0 : tensor<2x2xf32> + } + return %0 : tensor<2x2xf32> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir index bec48181b3b..726495f1fbc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir @@ -220,7 +220,7 @@ func @merge_islands_only() { %11:2 = tf_executor.island(%10#1) wraps "tf.opF"() : () -> tensor %12:2 = tf_executor.island wraps "tf.opG"(%10#0, %11#0) : (tensor<*xi32>, tensor) -> tensor<*xi32> %13 = tf_executor.ControlTrigger %2, %12#1, %9#1 - tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32> + tf_executor.NextIteration.Sink[%3#1] %12#0, %13 : tensor<*xi32> tf_executor.fetch } return @@ -244,7 +244,7 @@ func @merge_islands_only() { // CHECK-NEXT: %[[OP_G:[0-9]*]] = "tf.opG"(%[[OP_E]], %[[OP_F]]) // CHECK-NEXT: tf_executor.yield %[[OP_G]] : tensor<*xi32> // CHECK: %[[CT:.*]] = tf_executor.ControlTrigger %[[ISLAND_1]], %[[ISLAND_3_control]], %[[EXIT_control]] -// CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]] +// CHECK-NEXT: tf_executor.NextIteration.Sink[%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]] // Test no merging took place as cycle would be formed otherwise. 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/case_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/case_op.mlir index 7d761b5d690..0000d43823b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/case_op.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/case_op.mlir @@ -16,7 +16,7 @@ module { "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> () %index = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor %input = "tf.opB"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor - %result = "tf.Case"(%index, %input) {branches = [@branch_0, @branch_1, @branch_2, @branch_3, @branch_4]} : (tensor, tensor) -> tensor + %result = "tf.Case"(%index, %input) {branches = [@branch_0, @branch_1, @branch_2, @branch_3, @branch_4], is_stateless = false} : (tensor, tensor) -> tensor tf_executor.yield %result : tensor } tf_executor.fetch %output : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir index c8c82c5c08f..e4e7f0859c8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir @@ -123,6 +123,27 @@ func @testIfNoInputAndNoResult(%arg0: tensor) -> () { // ----- +// If with non tensor condition + +// Simple If +// CHECK: func @testIf1Then{{.+}} +// CHECK: func @testIf1Else{{.+}} +func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> +func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> + +// CHECK-LABEL: func @testIf1Result(%arg0: tensor, %arg1: tensor<*xf32>) +func @testIf1Result(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.If"(%arg0, %arg1) { + then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false + } : (tensor, tensor<*xf32>) -> tensor<*xf32> + + // CHECK: [[ToBool:%.*]] = "tf.ToBool" + // CHECK: "tf.IfRegion"([[ToBool]]) + return %0 : tensor<*xf32> +} + +// ----- + // Simple While func @testWhileCond(tensor<*xf32>) -> (tensor) func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>) @@ -200,3 +221,58 @@ func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { return %1 : tensor<*xf32> } +// ----- + +// While with non tensor condition +func @testWhileCond(tensor<*xf32>) -> (tensor) +func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>) + +// CHECK-LABEL: func @testWhileResult +func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { +^bb0(%arg0: tensor<*xf32>): + %1 = "tf.While"(%arg0) { + cond = @testWhileCond, + body = @testWhileBody, + is_stateless = true, + _attr0 = 10, _attr1 = true, attr2 = "hello" + } : (tensor<*xf32>) -> (tensor<*xf32>) + + // CHECK: [[Result0:%.*]] = "tf.WhileRegion" + // CHECK: [[Result1:%.*]] = call @testWhileCond + // CHECK: [[ToBool:%.*]] = "tf.ToBool"([[Result1]]) + // CHECK: "tf.Yield"([[ToBool]]) + // CHECK: [[Result2:%.*]] = call @testWhileBody + // CHECK: "tf.Yield"([[Result2]]) + // CHECK: return [[Result0]] + return %1 : tensor<*xf32> +} + +// ----- + +func @then_branch() -> () +func @else_branch() -> () + +// Test tf.If device is preserved. 
+// CHECK-LABEL: func @testIfDevice +func @testIfDevice(%arg0: tensor) { + "tf.If"(%arg0) {then_branch = @then_branch, else_branch = @else_branch, is_stateless = false, device = "/device:CPU:0"} : (tensor) -> () + + // CHECK: "tf.IfRegion" + // CHECK: device = "/device:CPU:0" + return +} + +// ----- + +func @cond() -> tensor +func @body() -> () + +// Test tf.While device is preserved. +// CHECK-LABEL: func @testWhileDevice +func @testWhileDevice() { + "tf.While"() {cond = @cond, body = @body, is_stateless = false, device = "/device:CPU:0"} : () -> () + + // CHECK: "tf.WhileRegion" + // CHECK: device = "/device:CPU:0" + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/case_op.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/case_op.pbtxt new file mode 100644 index 00000000000..1372ad71283 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/case_op.pbtxt @@ -0,0 +1,261 @@ +# RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - | FileCheck %s + +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "indexed_case" + op: "StatelessCase" + input: "Const_1" + input: "Const" + attr { + key: "Tin" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "_lower_using_switch_merge" + value { + b: true + } + } + attr { + key: "_read_only_resource_inputs" + value { + list { + } + } + } + attr { + key: "branches" + value { + list { + func { + name: "indexed_case_branch0_4" + } + func { + name: "indexed_case_branch1_5" + } + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "indexed_case/Identity" + op: "Identity" + input: "indexed_case" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +library { + function { + signature { + name: "indexed_case_branch0_4" + input_arg { + name: "add_const" + type: DT_INT32 + } + output_arg { + name: "add" + type: DT_INT32 + } + } + node_def { + name: "add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + experimental_debug_info { + original_node_names: "add/y" + } + } + node_def { + name: "add_0" + op: "AddV2" + input: "add_const" + input: "add/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + experimental_debug_info { + original_node_names: "add" + } + } + ret { + key: "add" + value: "add_0:z:0" + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + } + } + function { + signature { + name: "indexed_case_branch1_5" + input_arg { + name: "add_const" + type: DT_INT32 + } + output_arg { + name: "add" + type: DT_INT32 + } + } + node_def { + name: "add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + experimental_debug_info { + original_node_names: "add/y" + } + } + node_def { + name: "add_0" + op: "AddV2" + input: "add_const" + input: 
"add/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + experimental_debug_info { + original_node_names: "add" + } + } + ret { + key: "add" + value: "add_0:z:0" + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + } + } +} +versions { + producer: 486 + min_consumer: 12 +} + +# CHECK: tf.Case +# CHECK-SAME: is_stateless diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt index e21fd901a9e..a6b1979ee26 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt @@ -7,7 +7,7 @@ # CHECK: %[[NEXTITERATION:[a-z0-9]+]], %[[NEXTITERATION_token:[a-z0-9]+]], {{.*}} = tf_executor.NextIteration.Source # CHECK: tf_executor.Merge {{.*}} %[[NEXTITERATION]] -# CHECK: tf_executor.NextIteration.Sink [%[[NEXTITERATION_token]]] +# CHECK: tf_executor.NextIteration.Sink[%[[NEXTITERATION_token]]] node { name: "Const" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir index 30599b2e437..9bb05a75877 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir @@ -7,7 +7,7 @@ // CHECK-LABEL: func @transposeConv2D func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32> { - // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) // CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1) @@ -18,7 +18,7 @@ func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32 // CHECK-SAME: strides = [5, 8, 6, 7] // CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32> - // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]]) // CHECK: return %[[RES_TRANSPOSE]] @@ -38,7 +38,7 @@ func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32 func @transposeConv2DWithDefaultAttr(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<*xf32> { - // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) // CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1) @@ -49,7 +49,7 @@ func @transposeConv2DWithDefaultAttr(%input: tensor<1x32x32x3xf32>, %filter: ten // CHECK-SAME: strides = [5, 8, 6, 7] // CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<*xf32> - // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} // CHECK: 
%[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]]) // CHECK: return %[[RES_TRANSPOSE]] @@ -77,7 +77,7 @@ func @transposeConv2DBackpropFilter( // CHECK-SAME: dst_format = "NCHW" // CHECK-SAME: src_format = "NHWC" - // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[IN_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) // CHECK: %[[OUT_BP_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg2, %[[ARG_PERM]]) @@ -117,7 +117,7 @@ func @transposeConv2DBackpropInput( // CHECK-SAME: dst_format = "NCHW" // CHECK-SAME: src_format = "NHWC" - // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[OUT_BP_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg2, %[[ARG_PERM]]) // CHECK: %[[CONV2D_BACKPROP:[0-9]*]] = "tf.Conv2DBackpropInput" @@ -130,7 +130,7 @@ func @transposeConv2DBackpropInput( // CHECK-SAME: (tensor<4xi32>, tensor<1x1x3x8xf32>, tensor<1x8x32x32xf32>) // CHECK-SAME: -> tensor<1x3x32x32xf32> - // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D_BACKPROP]], %[[RES_PERM]]) // CHECK: return %[[RES_TRANSPOSE]] @@ -154,7 +154,7 @@ func @transposeFusedBatchNormV3( ) -> tensor<1x28x28x64xf32> { // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() - // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) // CHECK: "tf.FusedBatchNormV3" @@ -164,7 +164,7 @@ func @transposeFusedBatchNormV3( // CHECK-SAME: -> (tensor<1x64x28x28xf32>, tensor<64xf32>, // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() - // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]]) // CHECK: return %[[RES_TRANSPOSE]] @@ -192,7 +192,7 @@ func @transposeFusedBatchNormGradV3( ) -> tensor<1x28x28x64xf32> { // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() - // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[ARG0_TPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) // CHECK: %[[ARG1_TPOSE:[0-9]*]] = "tf.Transpose"(%arg1, %[[ARG_PERM]]) @@ -204,7 +204,7 @@ func @transposeFusedBatchNormGradV3( // CHECK-SAME: -> (tensor<1x64x28x28xf32>, // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() - // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} // CHECK: %[[RES_TPOSE:[0-9]*]] = "tf.Transpose" // CHECK-SAME: (%x_backprop, %[[RES_PERM]]) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir index e6b3bf08394..c71d8ef2850 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir @@ -7,7 +7,7 @@ // CHECK-LABEL: func @transposeConv2D func 
@transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32> { - // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) // CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1) @@ -18,7 +18,7 @@ func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32 // CHECK-SAME: strides = [5, 7, 8, 6] // CHECK-SAME: (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32> - // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]]) // CHECK: return %[[RES_TRANSPOSE]] @@ -41,7 +41,7 @@ func @transposeFusedBatchNormV3( ) -> tensor<1x64x28x28xf32> { // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() - // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) // CHECK: "tf.FusedBatchNormV3" @@ -51,7 +51,7 @@ func @transposeFusedBatchNormV3( // CHECK-SAME: -> (tensor<1x28x28x64xf32>, tensor<64xf32>, // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() - // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]]) // CHECK: return %[[RES_TRANSPOSE]] diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir index 0b1e27733eb..bacfeea2dc9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir @@ -65,3 +65,40 @@ func @move_with_multiple_uses(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> return %3 : tensor<1x8x4x4xf32> } + +// CHECK-LABEL: move_transpose_handle_broadcast +func @move_transpose_handle_broadcast(%arg0:tensor<8x64xf32>, %arg1:tensor<8x64x64xf32>) -> tensor<512x64xf32> { + %cst = "tf.Const"() {value = dense<3> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %cst_2 = "tf.Const"() {value = dense<[512, 64]> : tensor<2xi32>} : () -> tensor<2xi32> + %0 = "tf.ExpandDims"(%arg0, %cst) {device = ""} : (tensor<8x64xf32>, tensor) -> tensor<8x64x1xf32> + %1 = "tf.AddV2"(%0, %arg1) {device = ""} : (tensor<8x64x1xf32>, tensor<8x64x64xf32>) -> tensor<8x64x64xf32> + %2 = "tf.Transpose"(%1, %cst_1) {device = ""} : (tensor<8x64x64xf32>, tensor<3xi32>) -> tensor<64x8x64xf32> + %3 = "tf.Reshape"(%2, %cst_2) {device = ""} : (tensor<64x8x64xf32>, tensor<2xi32>) -> tensor<512x64xf32> + + return %3 : tensor<512x64xf32> + + // CHECK: %[[CST_0:.*]] = "tf.Const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK: %[[CST_1:.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor + // CHECK: %[[CST_2:.*]] = "tf.Const"() {value = dense<[512, 64]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: %[[EXPAND_DIMS:.*]] = "tf.ExpandDims"(%arg0, %[[CST_1]]) {device = ""} : (tensor<8x64xf32>, tensor) -> 
tensor<8x64x1xf32> + // CHECK: %[[TRANSPOSE_1:.*]] = "tf.Transpose"(%[[EXPAND_DIMS]], %[[CST_0]]) : (tensor<8x64x1xf32>, tensor<3xi32>) -> tensor<1x8x64xf32> + // CHECK: %[[TRANSPOSE_2:.*]] = "tf.Transpose"(%arg1, %[[CST_0]]) : (tensor<8x64x64xf32>, tensor<3xi32>) -> tensor<64x8x64xf32> + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[TRANSPOSE_1]], %[[TRANSPOSE_2]]) {device = ""} : (tensor<1x8x64xf32>, tensor<64x8x64xf32>) -> tensor<64x8x64xf32> + // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%[[ADD]], %[[CST_2]]) {device = ""} : (tensor<64x8x64xf32>, tensor<2xi32>) -> tensor<512x64xf32> + // CHECK: return %[[RESHAPE]] : tensor<512x64xf32> +} + +// CHECK-LABEL: dont_move_transpose_different_ranks +func @dont_move_transpose_different_ranks(%arg0:tensor<1x1x2x3xf32>, %arg1:tensor<2x3xf32>) -> tensor<1x2x1x3xf32> { + %cst = "tf.Const"() {value = dense<[0, 2, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32> + %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor<1x1x2x3xf32>, tensor<2x3xf32>) -> tensor<1x1x2x3xf32> + %1 = "tf.Transpose"(%0, %cst) {device = ""} : (tensor<1x1x2x3xf32>, tensor<4xi32>) -> tensor<1x2x1x3xf32> + + return %1 : tensor<1x2x1x3xf32> + + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 2, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32> + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor<1x1x2x3xf32>, tensor<2x3xf32>) -> tensor<1x1x2x3xf32> + // CHECK: %[[TRANSPOSE:.*]] = "tf.Transpose"(%[[ADD]], %[[CST]]) {device = ""} : (tensor<1x1x2x3xf32>, tensor<4xi32>) -> tensor<1x2x1x3xf32> + // CHECK: return %[[TRANSPOSE]] : tensor<1x2x1x3xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 4f044cd5eff..9864cffee7c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -1,177 +1,396 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py // RUN: tf-opt -tf-legalize-hlo %s | FileCheck %s +// CHECK-LABEL: func @biasAdd_NHWC( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x32x10x32xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi32>) -> tensor<1x32x10x32xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.AddV2"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> +// CHECK: return %[[VAL_2]] : tensor<1x32x10x32xi32> +// CHECK: } func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } +// CHECK-LABEL: func @biasAdd_NCHW( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x32x10x32xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi32>) -> tensor<1x32x10x32xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.AddV2"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> +// CHECK: return %[[VAL_2]] : tensor<1x32x10x32xi32> +// CHECK: } func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } +// CHECK-LABEL: func @biasAdd_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.AddV2"(%[[VAL_0]], 
%[[VAL_1]]) : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @add( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.AddV2"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: %[[VAL_2:.*]] = "tf.AddV2"(%[[VAL_1]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_2]] : tensor<2xi32> +// CHECK: } func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { %0 = mhlo.add %arg0, %arg0 : tensor<2xi32> %1 = mhlo.add %0, %arg0 : tensor<2xi32> return %1 : tensor<2xi32> } +// CHECK-LABEL: func @broadcast_add( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.AddV2"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return %[[VAL_2]] : tensor<1x2xi32> +// CHECK: } func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } +// CHECK-LABEL: func @broadcast_multi_dim_add( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x1x1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.AddV2"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> +// CHECK: return %[[VAL_2]] : tensor<4x4x4x4xi32> +// CHECK: } func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> return %0 : tensor<4x4x4x4xi32> } +// CHECK-LABEL: func @div( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_1]] : tensor<2xi32> +// CHECK: } func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @broadcast_div( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return %[[VAL_2]] : tensor<1x2xi32> +// CHECK: } func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } +// CHECK-LABEL: func @shift_left( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.LeftShift"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return %[[VAL_2]] : tensor<4xi32> +// CHECK: } func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { %0 = mhlo.shift_left %arg0, %arg1 : tensor<4xi32> return %0 : tensor<4xi32> } 
+// CHECK-LABEL: func @div_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @maximum( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xf32>) -> tensor<4xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[VAL_2]] : tensor<4xf32> +// CHECK: } func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { %0 = mhlo.maximum %arg0, %arg1 : tensor<4xf32> return %0 : tensor<4xf32> } +// CHECK-LABEL: func @minimum( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xf32>) -> tensor<4xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Minimum"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[VAL_2]] : tensor<4xf32> +// CHECK: } func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { %0 = mhlo.minimum %arg0, %arg1 : tensor<4xf32> return %0 : tensor<4xf32> } +// CHECK-LABEL: func @mul( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.Mul"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_1]] : tensor<2xi32> +// CHECK: } func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @broadcast_mul( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.Mul"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return %[[VAL_2]] : tensor<1x2xi32> +// CHECK: } func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { %0 = "chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } +// CHECK-LABEL: func @real_div( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_1]] : tensor<2xi32> +// CHECK: } func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @broadcast_real_div( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return %[[VAL_2]] : tensor<1x2xi32> +// CHECK: } func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } +// CHECK-LABEL: func @sub( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.Sub"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, 
tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_1]] : tensor<2xi32> +// CHECK: } func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { %0 = mhlo.subtract %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @broadcast_sub( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.Sub"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return %[[VAL_2]] : tensor<1x2xi32> +// CHECK: } func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { %0 = "chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } +// CHECK-LABEL: func @shift_right( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.RightShift"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return %[[VAL_2]] : tensor<4xi32> +// CHECK: } func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { %0 = mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> return %0 : tensor<4xi32> } +// CHECK-LABEL: func @broadcast_shift_right( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x4xi32>) -> tensor<2x4xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.RightShift"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> +// CHECK: return %[[VAL_2]] : tensor<2x4xi32> +// CHECK: } func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { %0 = "chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> return %0 : tensor<2x4xi32> } -func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { - %0 = mhlo.and %arg0, %arg0 : tensor<2xi1> +// CHECK-LABEL: func @and( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi1>) -> tensor<2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.LogicalAnd"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> +// CHECK: return %[[VAL_2]] : tensor<2xi1> +// CHECK: } +func @and(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { + %0 = mhlo.and %arg0, %arg1 : tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @and_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi1>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.LogicalAnd"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @and_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi1>) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.LogicalAnd"(%[[VAL_0]], %[[VAL_1]]) : (tensor, tensor<1xi1>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @and_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { %0 = "chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor return %0 : tensor } -func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { - 
%0 = mhlo.or %arg0, %arg0 : tensor<2xi1> +// CHECK-LABEL: func @or( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi1>) -> tensor<2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.LogicalOr"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> +// CHECK: return %[[VAL_2]] : tensor<2xi1> +// CHECK: } +func @or(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { + %0 = mhlo.or %arg0, %arg1 : tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @or_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi1>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.LogicalOr"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @or_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi1>) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.LogicalOr"(%[[VAL_0]], %[[VAL_1]]) : (tensor, tensor<1xi1>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @or_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { %0 = "chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor return %0 : tensor } +// CHECK-LABEL: func @bitwise_or( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.BitwiseOr"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return %[[VAL_2]] : tensor<4xi32> +// CHECK: } func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { %0 = mhlo.or %arg0, %arg1 : tensor<4xi32> return %0 : tensor<4xi32> } +// CHECK-LABEL: func @bitwise_or_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi8>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> { +// CHECK: %[[VAL_2:.*]] = "tf.BitwiseOr"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> +// CHECK: return %[[VAL_2]] : tensor<1x4xi8> +// CHECK: } func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } +// CHECK-LABEL: func @bitwise_or_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi32>) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.BitwiseOr"(%[[VAL_0]], %[[VAL_1]]) : (tensor, tensor<1xi32>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @bitwise_or_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { %0 = "chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } +// CHECK-LABEL: func @bitwise_and( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.BitwiseAnd"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return %[[VAL_2]] : tensor<4xi32> +// CHECK: } func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { %0 = mhlo.and %arg0, %arg1 : tensor<4xi32> return %0 : tensor<4xi32> } +// CHECK-LABEL: func @bitwise_and_broadcast( 
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi8>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> { +// CHECK: %[[VAL_2:.*]] = "tf.BitwiseAnd"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> +// CHECK: return %[[VAL_2]] : tensor<1x4xi8> +// CHECK: } func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } +// CHECK-LABEL: func @bitwise_and_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi32>) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.BitwiseAnd"(%[[VAL_0]], %[[VAL_1]]) : (tensor, tensor<1xi32>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @bitwise_and_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { %0 = "chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } +// CHECK-LABEL: func @pow( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Pow"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = mhlo.power %arg0, %arg0 : tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @pow_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Pow"(%[[VAL_0]], %[[VAL_0]]) : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @pow_dynamic(%arg0: tensor) -> tensor { %0 = mhlo.power %arg0, %arg0 : tensor return %0 : tensor } +// CHECK-LABEL: func @floordiv_broadcast_i32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3xi32>) -> tensor<2x3xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32> +// CHECK: %[[VAL_3:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_2]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> +// CHECK: %[[VAL_4:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> +// CHECK: %[[VAL_5:.*]] = "tf.Less"(%[[VAL_1]], %[[VAL_4]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> +// CHECK: %[[VAL_6:.*]] = "tf.Equal"(%[[VAL_3]], %[[VAL_5]]) {incompatible_shape_error = true} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> +// CHECK: %[[VAL_7:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_8:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_9:.*]] = "tf.Abs"(%[[VAL_1]]) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %[[VAL_10:.*]] = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> +// CHECK: %[[VAL_11:.*]] = "tf.Sub"(%[[VAL_9]], %[[VAL_10]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> +// CHECK: %[[VAL_12:.*]] = "tf.AddV2"(%[[VAL_8]], %[[VAL_11]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_13:.*]] = "tf.Neg"(%[[VAL_12]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_14:.*]] = "tf.Abs"(%[[VAL_1]]) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %[[VAL_15:.*]] = "tf.Div"(%[[VAL_13]], %[[VAL_14]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_16:.*]] = "tf.Select"(%[[VAL_6]], %[[VAL_7]], %[[VAL_15]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: return %[[VAL_16]] : 
tensor<2x3xi32> +// CHECK: } func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { %0 = mhlo.constant dense<0> : tensor<2x3xi32> %1 = "chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> @@ -191,6 +410,26 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te return %14 : tensor<2x3xi32> } +// CHECK-LABEL: func @floordiv_reverse_broadcast_i32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3xi32>) -> tensor<2x3xi32> { +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> +// CHECK: %[[VAL_3:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> +// CHECK: %[[VAL_4:.*]] = "tf.Const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32> +// CHECK: %[[VAL_5:.*]] = "tf.Less"(%[[VAL_1]], %[[VAL_4]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> +// CHECK: %[[VAL_6:.*]] = "tf.Equal"(%[[VAL_3]], %[[VAL_5]]) {incompatible_shape_error = true} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> +// CHECK: %[[VAL_7:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_8:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %[[VAL_9:.*]] = "tf.Abs"(%[[VAL_1]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_10:.*]] = "tf.Const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32> +// CHECK: %[[VAL_11:.*]] = "tf.Sub"(%[[VAL_9]], %[[VAL_10]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_12:.*]] = "tf.AddV2"(%[[VAL_8]], %[[VAL_11]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_13:.*]] = "tf.Neg"(%[[VAL_12]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_14:.*]] = "tf.Abs"(%[[VAL_1]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_15:.*]] = "tf.Div"(%[[VAL_13]], %[[VAL_14]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_16:.*]] = "tf.Select"(%[[VAL_6]], %[[VAL_7]], %[[VAL_15]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: return %[[VAL_16]] : tensor<2x3xi32> +// CHECK: } func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { %0 = mhlo.constant dense<0> : tensor<3xi32> %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> @@ -210,6 +449,13 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 return %14 : tensor<2x3xi32> } +// CHECK-LABEL: func @floordiv_f32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: %[[VAL_2:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: %[[VAL_3:.*]] = "tf.FloorDiv"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_3]] : tensor<2xf32> +// CHECK: } func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = mhlo.divide %arg0, %arg0 : tensor<2xf32> %1 = mhlo.divide %arg0, %arg0 : tensor<2xf32> @@ -217,6 +463,14 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { return %2 : tensor<2xf32> } +// CHECK-LABEL: func @floordiv_f16_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3xf16>, +// 
CHECK-SAME: %[[VAL_1:.*]]: tensor<3xf16>) -> tensor<2x3xf16> { +// CHECK: %[[VAL_2:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> +// CHECK: %[[VAL_3:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> +// CHECK: %[[VAL_4:.*]] = "tf.FloorDiv"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> +// CHECK: return %[[VAL_4]] : tensor<2x3xf16> +// CHECK: } func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> %1 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> @@ -224,118 +478,252 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te return %2 : tensor<2x3xf16> } +// CHECK-LABEL: func @equal( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_1:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_0]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_1]] : tensor<2xi1> +// CHECK: } func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @equal_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi32>) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } +// CHECK-LABEL: func @equal_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @equal_incompatible_shape_broadcastable( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi32>) -> tensor { +// CHECK: %[[VAL_2:.*]] 
= "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } +// CHECK-LABEL: func @notequal( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_1:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_0]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_1]] : tensor<2xi1> +// CHECK: } func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @notequal_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @notequal_broadcast_no_incompatible_shapes_error( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @notequal_incompatible_shape_broadcastable( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi32>) -> tensor { +// CHECK: %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @notequal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } +// CHECK-LABEL: func @greater( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_1:.*]] = "tf.Greater"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_1]] : tensor<2xi1> +// CHECK: } func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @broadcast_greater( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.Greater"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, 
tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @greater_equal( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_1:.*]] = "tf.GreaterEqual"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_1]] : tensor<2xi1> +// CHECK: } func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @broadcast_greater_equal( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.GreaterEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @less( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_1:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_1]] : tensor<2xi1> +// CHECK: } func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @broadcast_less( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @less_equal( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_1:.*]] = "tf.LessEqual"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_1]] : tensor<2xi1> +// CHECK: } func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @broadcast_less_equal( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.LessEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return %[[VAL_2]] : tensor<1x2xi1> +// CHECK: } func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> 
: tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @concat_v2( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xf32>) -> tensor<6x3xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[VAL_3:.*]] = "tf.ConcatV2"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> +// CHECK: return %[[VAL_3]] : tensor<6x3xf32> +// CHECK: } func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { %2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> return %2 : tensor<6x3xf32> } +// CHECK-LABEL: func @concat_v2_1d_axis( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xf32>) -> tensor<3x6xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor +// CHECK: %[[VAL_3:.*]] = "tf.ConcatV2"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<3x6xf32> +// CHECK: return %[[VAL_3]] : tensor<3x6xf32> +// CHECK: } func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> { %2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> return %2 : tensor<3x6xf32> } +// CHECK-LABEL: func @const() -> tensor<2xi32> { +// CHECK: %[[VAL_0:.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: return %[[VAL_0]] : tensor<2xi32> +// CHECK: } func @const() -> tensor<2xi32> { %0 = mhlo.constant dense<0> : tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @relu( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>) -> tensor<1xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_1]], %[[VAL_0]]) : (tensor, tensor<1xi32>) -> tensor<1xi32> +// CHECK: return %[[VAL_2]] : tensor<1xi32> +// CHECK: } func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = mhlo.constant dense<0> : tensor %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> return %1 : tensor<1xi32> } +// CHECK-LABEL: func @relu_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_1]], %[[VAL_0]]) : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_2]] : tensor +// CHECK: } func @relu_unranked(%arg0: tensor) -> tensor { %0 = mhlo.constant dense<0> : tensor %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor return %1 : tensor } +// CHECK-LABEL: func @relu6( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>) -> tensor<1xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<6> : tensor} : () -> tensor +// CHECK: %[[VAL_3:.*]] = "tf.Minimum"(%[[VAL_0]], %[[VAL_2]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> +// CHECK: %[[VAL_4:.*]] = "tf.Maximum"(%[[VAL_3]], %[[VAL_1]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> +// CHECK: return %[[VAL_4]] : tensor<1xi32> +// CHECK: } func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = mhlo.constant dense<0> : tensor %1 = 
mhlo.constant dense<6> : tensor @@ -344,6 +732,14 @@ func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { return %3 : tensor<1xi32> } +// CHECK-LABEL: func @relu6_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<6> : tensor} : () -> tensor +// CHECK: %[[VAL_3:.*]] = "tf.Minimum"(%[[VAL_0]], %[[VAL_2]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = "tf.Maximum"(%[[VAL_3]], %[[VAL_1]]) : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_4]] : tensor +// CHECK: } func @relu6_unranked(%arg0: tensor) -> tensor { %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<6> : tensor @@ -352,6 +748,15 @@ func @relu6_unranked(%arg0: tensor) -> tensor { return %3 : tensor } +// CHECK-LABEL: func @relu_grad( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor<4x8xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} : () -> tensor +// CHECK: %[[VAL_3:.*]] = "tf.Greater"(%[[VAL_1]], %[[VAL_2]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x8xf32>} : () -> tensor<4x8xf32> +// CHECK: %[[VAL_5:.*]] = "tf.Select"(%[[VAL_3]], %[[VAL_0]], %[[VAL_4]]) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> +// CHECK: return %[[VAL_5]] : tensor<4x8xf32> +// CHECK: } func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor) -> tensor<4x8xf32> { %0 = mhlo.constant dense<0.000000e+00> : tensor %1 = "chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor, tensor) -> tensor @@ -360,31 +765,74 @@ func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor) -> tensor<4x8xf3 return %3 : tensor<4x8xf32> } +// CHECK-LABEL: func @select( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_3]] : tensor<2xi32> +// CHECK: } func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @select_float( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_3]] : tensor<2xf32> +// CHECK: } func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @select_multidimensional( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x2xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<3x2xi32>) -> tensor<3x2xi32> { +// CHECK: %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> +// CHECK: return %[[VAL_3]] : tensor<3x2xi32> +// CHECK: } func 
@select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> { %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> return %0 : tensor<3x2xi32> } +// CHECK-LABEL: func @selectv2( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi1>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_3]] : tensor<2xi32> +// CHECK: } func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @selectv2_pred_scalar( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return %[[VAL_3]] : tensor<2xi32> +// CHECK: } func @selectv2_pred_scalar(%arg0: tensor, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @transpose_2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3xf32>) -> tensor<3x2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> +// CHECK: return %[[VAL_4]] : tensor<3x2xf32> +// CHECK: } func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64> %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64> @@ -392,6 +840,14 @@ func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { return %2 : tensor<3x2xf32> } +// CHECK-LABEL: func @transpose_3d_int32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> tensor<3xi32> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> +// CHECK: return %[[VAL_4]] : tensor<3x2x1xf32> +// CHECK: } func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { %0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi32> %1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64> @@ -399,6 +855,14 @@ func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { return %2 : tensor<3x2x1xf32> } +// CHECK-LABEL: func @transpose_3d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = 
dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> +// CHECK: return %[[VAL_4]] : tensor<3x2x1xf32> +// CHECK: } func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { %0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64> %1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64> @@ -406,6 +870,14 @@ func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { return %2 : tensor<3x2x1xf32> } +// CHECK-LABEL: func @transpose_dynamic_2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor<4x?xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor, tensor<2xi64>) -> tensor<4x?xf32> +// CHECK: return %[[VAL_4]] : tensor<4x?xf32> +// CHECK: } func @transpose_dynamic_2d(%arg0: tensor) -> tensor<4x?xf32> { %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64> %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64> @@ -413,6 +885,14 @@ func @transpose_dynamic_2d(%arg0: tensor) -> tensor<4x?xf32> { return %2 : tensor<4x?xf32> } +// CHECK-LABEL: func @transpose_unranked_2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32> +// CHECK: return %[[VAL_4]] : tensor<*xf32> +// CHECK: } func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64> %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64> @@ -420,146 +900,297 @@ func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> { return %2 : tensor<*xf32> } +// CHECK-LABEL: func @abs( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @abs_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @abs_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.abs"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @abs_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @ceil( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// 
CHECK: %[[VAL_1:.*]] = "tf.Ceil"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @ceil_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Ceil"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @ceil_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.ceil"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @ceil_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Ceil"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @complex_abs( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xcomplex>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.ComplexAbs"(%[[VAL_0]]) : (tensor<2xcomplex>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @complex_abs(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { %0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @cos( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Cos"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @cos_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Cos"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @cos_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.cosine"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @cos_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Cos"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @exp( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Exp"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @exp_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Exp"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @exp_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.exponential"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @exp_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Exp"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.exponential"(%arg0) : (tensor<*xf32>) -> 
tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @floor( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Floor"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @floor_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Floor"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @floor_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.floor"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @floor_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Floor"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @is_finite( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xi1> { +// CHECK: %[[VAL_1:.*]] = "tf.IsFinite"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xi1> +// CHECK: return %[[VAL_1]] : tensor<2xi1> +// CHECK: } func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { %0 = "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> return %0 : tensor<2xi1> } +// CHECK-LABEL: func @is_finite_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.IsFinite"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @is_finite_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.is_finite"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @is_finite_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xi1> { +// CHECK: %[[VAL_1:.*]] = "tf.IsFinite"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xi1> +// CHECK: return %[[VAL_1]] : tensor<*xi1> +// CHECK: } func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> { %0 = "mhlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1> return %0 : tensor<*xi1> } +// CHECK-LABEL: func @log( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Log"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @log_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Log"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @log_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.log"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @log_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Log"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @log1p( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Log1p"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return 
%[[VAL_1]] : tensor<2xf32> +// CHECK: } func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @log1p_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Log1p"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @log1p_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.log_plus_one"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @log1p_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Log1p"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @neg( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Neg"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @neg_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Neg"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @neg_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.negate"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @neg_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Neg"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @sigmoid( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32> +// CHECK: %[[VAL_4:.*]] = "tf.Mul"(%[[VAL_0]], %[[VAL_3]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: %[[VAL_5:.*]] = "tf.Tanh"(%[[VAL_4]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: %[[VAL_6:.*]] = "tf.Mul"(%[[VAL_5]], %[[VAL_3]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: %[[VAL_7:.*]] = "tf.AddV2"(%[[VAL_6]], %[[VAL_3]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_7]] : tensor<2xf32> +// CHECK: } func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = mhlo.constant dense<5.000000e-01> : tensor %1 = mhlo.constant dense<2> : tensor<1xi64> @@ -571,86 +1202,177 @@ func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { return %6 : tensor<2xf32> } +// CHECK-LABEL: func @sin( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Sin"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @sin_dynamic( +// 
CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Sin"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @sin_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.sine"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @sin_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Sin"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @rsqrt( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Rsqrt"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @rsqrt_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Rsqrt"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @rsqrt_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.rsqrt"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @rsqrt_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Rsqrt"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @sqrt( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Sqrt"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @sqrt_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Sqrt"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @sqrt_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.sqrt"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @sqrt_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Sqrt"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @tanh( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Tanh"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @tanh_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Tanh"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @tanh_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.tanh"(%arg0) : (tensor) -> tensor return %0 : tensor } 
+// CHECK-LABEL: func @tanh_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Tanh"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @bitcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Bitcast"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @bitcast_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Bitcast"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @bitcast_dynamic(%arg0: tensor) -> tensor { %0 = "mhlo.bitcast_convert"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @bitcast_unranked( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Bitcast"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_1]] : tensor<*xf32> +// CHECK: } func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @bitcast_same_widths( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xi32> { +// CHECK: %[[VAL_1:.*]] = "tf.Bitcast"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xi32> +// CHECK: return %[[VAL_1]] : tensor<2xi32> +// CHECK: } func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> return %0 : tensor<2xi32> } +// CHECK-LABEL: func @sign( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_0]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_3:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_0]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> +// CHECK: %[[VAL_4:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_5:.*]] = "tf.Sign"(%[[VAL_0]]) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_6:.*]] = "tf.Select"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_7:.*]] = "tf.Select"(%[[VAL_1]], %[[VAL_2]], %[[VAL_6]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> +// CHECK: return %[[VAL_7]] : tensor<1x2x3x4xf32> +// CHECK: } func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> %1 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> @@ -662,72 +1384,180 @@ func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { return %6 : tensor<1x2x3x4xf32> } +// CHECK-LABEL: func @size_rank_one_i32( +// 
CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @size_rank_one_i32(%arg0: tensor) -> tensor { %0 = mhlo.constant dense<1> : tensor return %0 : tensor } +// CHECK-LABEL: func @size_rank_one_i64( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor +// CHECK: return %[[VAL_1]] : tensor +// CHECK: } func @size_rank_one_i64(%arg0: tensor) -> tensor { %0 = mhlo.constant dense<1> : tensor return %0 : tensor } +// CHECK-LABEL: func @complex( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3xf32>) -> tensor<3xcomplex> { +// CHECK: %[[VAL_2:.*]] = "tf.Complex"(%[[VAL_0]], %[[VAL_1]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> +// CHECK: return %[[VAL_2]] : tensor<3xcomplex> +// CHECK: } func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex> { %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> return %0 : tensor<3xcomplex> } +// CHECK-LABEL: func @convert_i32_f32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Cast"(%[[VAL_0]]) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> { %0 = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: func @convert_slice( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x4672xf32>) -> tensor<1x519xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<[0, 4153]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 519]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Slice"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<1x4672xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x519xf32> +// CHECK: return %[[VAL_3]] : tensor<1x519xf32> +// CHECK: } func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> { %0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32> return %0 : tensor<1x519xf32> } +// CHECK-LABEL: func @reshape( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x6xf32>) -> tensor<2x2x6xf32> { +// CHECK: %[[VAL_1:.*]] = constant dense<[2, 2, 6]> : tensor<3xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_1]]) : (tensor<4x6xf32>, tensor<3xi64>) -> tensor<2x2x6xf32> +// CHECK: return %[[VAL_2]] : tensor<2x2x6xf32> +// CHECK: } func @reshape(%arg0: tensor<4x6xf32>) -> tensor<2x2x6xf32> { %0 = "mhlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32> return %0 : tensor<2x2x6xf32> } +// CHECK-LABEL: func @convert_dot_1d_2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<256xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<256x1xf32>) -> tensor<1xf32> { +// CHECK: %[[VAL_2:.*]] = constant dense<[1, 256]> : tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[VAL_4:.*]] = "tf.MatMul"(%[[VAL_3]], %[[VAL_1]]) {transpose_a = false, transpose_b = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = constant dense<1> : tensor<1xi64> +// CHECK: %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], 
%[[VAL_5]]) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32> +// CHECK: return %[[VAL_6]] : tensor<1xf32> +// CHECK: } func @convert_dot_1d_2d(%arg0: tensor<256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1xf32> { %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256x1xf32>) -> tensor<1xf32> return %0 : tensor<1xf32> } +// CHECK-LABEL: func @convert_dot_2d_1d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x256xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<256xf32>) -> tensor<1xf32> { +// CHECK: %[[VAL_2:.*]] = constant dense<[1, 256]> : tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_2]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[VAL_4:.*]] = "tf.MatMul"(%[[VAL_0]], %[[VAL_3]]) {transpose_a = false, transpose_b = true} : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = constant dense<1> : tensor<1xi64> +// CHECK: %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], %[[VAL_5]]) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32> +// CHECK: return %[[VAL_6]] : tensor<1xf32> +// CHECK: } func @convert_dot_2d_1d(%arg0: tensor<1x256xf32>, %arg1: tensor<256xf32>) -> tensor<1xf32> { %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256xf32>) -> tensor<1xf32> return %0 : tensor<1xf32> } +// CHECK-LABEL: func @convert_dot_1d_1d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<256xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<256xf32>) -> tensor { +// CHECK: %[[VAL_2:.*]] = constant dense<[1, 256]> : tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[VAL_4:.*]] = constant dense<[1, 256]> : tensor<2xi64> +// CHECK: %[[VAL_5:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[VAL_6:.*]] = "tf.MatMul"(%[[VAL_3]], %[[VAL_5]]) {transpose_a = false, transpose_b = true} : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = constant dense<> : tensor<0xi64> +// CHECK: %[[VAL_8:.*]] = "tf.Reshape"(%[[VAL_6]], %[[VAL_7]]) : (tensor<1x1xf32>, tensor<0xi64>) -> tensor +// CHECK: return %[[VAL_8]] : tensor +// CHECK: } func @convert_dot_1d_1d(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>) -> tensor { %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256xf32>) -> tensor return %0 : tensor } +// CHECK-LABEL: func @convert_dot_2d_2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x256xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<256x1xf32>) -> tensor<1x1xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.MatMul"(%[[VAL_0]], %[[VAL_1]]) {transpose_a = false, transpose_b = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> +// CHECK: return %[[VAL_2]] : tensor<1x1xf32> +// CHECK: } func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1x1xf32> { %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> return %0 : tensor<1x1xf32> } +// CHECK-LABEL: func @broadcast_in_dim_tf_style( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> { +// CHECK: %[[VAL_1:.*]] = constant dense<[3, 8, 8, 16]> : tensor<4xi64> +// CHECK: %[[VAL_2:.*]] = "tf.BroadcastTo"(%[[VAL_0]], %[[VAL_1]]) : (tensor<8x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32> +// CHECK: return %[[VAL_2]] : tensor<3x8x8x16xf32> +// CHECK: } 
func @broadcast_in_dim_tf_style(%arg0: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> { %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> return %0 : tensor<3x8x8x16xf32> } +// CHECK-LABEL: func @broadcast_in_dim_general_case( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> { +// CHECK: %[[VAL_1:.*]] = constant dense<[3, 1, 1, 16]> : tensor<4xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_1]]) : (tensor<3x1x16xf32>, tensor<4xi64>) -> tensor<3x1x1x16xf32> +// CHECK: %[[VAL_3:.*]] = constant dense<[3, 8, 8, 16]> : tensor<4xi64> +// CHECK: %[[VAL_4:.*]] = "tf.BroadcastTo"(%[[VAL_2]], %[[VAL_3]]) : (tensor<3x1x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32> +// CHECK: return %[[VAL_4]] : tensor<3x8x8x16xf32> +// CHECK: } func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> { %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> return %0 : tensor<3x8x8x16xf32> } +// CHECK-LABEL: func @convert_dot_general( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x6x5x1xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<[0, 3, 4, 1, 2]> : tensor<5xi64>} : () -> tensor<5xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : (tensor<3x2x6x5x1xf32>, tensor<5xi64>) -> tensor<3x5x1x2x6xf32> +// CHECK: %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_1]], %[[VAL_4]]) : (tensor<3x2x4x6xf32>, tensor<4xi64>) -> tensor<3x2x6x4xf32> +// CHECK: %[[VAL_6:.*]] = constant dense<[3, 5, 12]> : tensor<3xi64> +// CHECK: %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_3]], %[[VAL_6]]) : (tensor<3x5x1x2x6xf32>, tensor<3xi64>) -> tensor<3x5x12xf32> +// CHECK: %[[VAL_8:.*]] = constant dense<[3, 12, 4]> : tensor<3xi64> +// CHECK: %[[VAL_9:.*]] = "tf.Reshape"(%[[VAL_5]], %[[VAL_8]]) : (tensor<3x2x6x4xf32>, tensor<3xi64>) -> tensor<3x12x4xf32> +// CHECK: %[[VAL_10:.*]] = "tf.BatchMatMulV2"(%[[VAL_7]], %[[VAL_9]]) {adj_x = false, adj_y = false} : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32> +// CHECK: %[[VAL_11:.*]] = constant dense<[3, 5, 1, 4]> : tensor<4xi64> +// CHECK: %[[VAL_12:.*]] = "tf.Reshape"(%[[VAL_10]], %[[VAL_11]]) : (tensor<3x5x4xf32>, tensor<4xi64>) -> tensor<3x5x1x4xf32> +// CHECK: return %[[VAL_12]] : tensor<3x5x1x4xf32> +// CHECK: } func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> { %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> return %0 : tensor<3x5x1x4xf32> } +// CHECK-LABEL: func @convert_conv2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x8x207xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], 
use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> +// CHECK: return %[[VAL_2]] : tensor<1x8x8x16xf32> +// CHECK: } func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, @@ -736,6 +1566,12 @@ func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32> return %0 : tensor<1x8x8x16xf32> } +// CHECK-LABEL: func @convert_depthwise_conv2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x8x207xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.DepthwiseConv2dNative"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> +// CHECK: return %[[VAL_2]] : tensor<1x8x8x16xf32> +// CHECK: } func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, @@ -744,6 +1580,12 @@ func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x2 return %0 : tensor<1x8x8x16xf32> } +// CHECK-LABEL: func @convert_conv2d_valid_padding( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x8x207xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> +// CHECK: return %[[VAL_2]] : tensor<1x8x8x16xf32> +// CHECK: } func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, @@ -752,6 +1594,13 @@ func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3 return %0 : tensor<1x8x8x16xf32> } +// CHECK-LABEL: func @convert_reduce_to_sum( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x256xf32>) -> tensor<1xf32> { 
+// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Sum"(%[[VAL_0]], %[[VAL_2]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32> +// CHECK: return %[[VAL_3]] : tensor<1xf32> +// CHECK: } func @convert_reduce_to_sum(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { %0 = mhlo.constant dense<0.000000e+00> : tensor %1 = "mhlo.reduce"(%arg0, %0) ( { @@ -762,6 +1611,13 @@ func @convert_reduce_to_sum(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { return %1 : tensor<1xf32> } +// CHECK-LABEL: func @convert_reduce_to_max( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x256xf32>) -> tensor<1xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0xFF800000> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_2]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32> +// CHECK: return %[[VAL_3]] : tensor<1xf32> +// CHECK: } func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { // "0xFF800000" represents -INF for f32. %0 = mhlo.constant dense<0xFF800000> : tensor @@ -773,7 +1629,13 @@ func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { return %1 : tensor<1xf32> } - +// CHECK-LABEL: func @convert_reduce_to_min( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x256xf32>) -> tensor<1xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0x7F800000> : tensor} : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_2]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32> +// CHECK: return %[[VAL_3]] : tensor<1xf32> +// CHECK: } func @convert_reduce_to_min(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { // "0x7F800000" represents INF for f32. 
   %0 = mhlo.constant dense<0x7F800000> : tensor<f32>
@@ -785,928 +1647,31 @@ func @convert_reduce_to_min(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
   return %1 : tensor<1xf32>
 }
+// CHECK-LABEL: func @convert_iota_1d() -> tensor<123xf32> {
+// CHECK: %[[VAL_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<1.230000e+02> : tensor<f32>} : () -> tensor<f32>
+// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK: %[[VAL_3:.*]] = "tf.Range"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<123xf32>
+// CHECK: return %[[VAL_3]] : tensor<123xf32>
+// CHECK: }
+func @convert_iota_1d() -> tensor<123xf32> {
+  %0 = "mhlo.iota"() { iota_dimension = 0 : i64 } : () -> tensor<123xf32>
+  return %0 : tensor<123xf32>
+}
+
+// CHECK-LABEL: func @convert_iota_3d() -> tensor<5x7x9xi32> {
+// CHECK: %[[VAL_0:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<7> : tensor<i32>} : () -> tensor<i32>
+// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK: %[[VAL_3:.*]] = "tf.Range"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<7xi32>
+// CHECK: %[[VAL_4:.*]] = "tf.Const"() {value = dense<[1, 7, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
+// CHECK: %[[VAL_5:.*]] = "tf.Reshape"(%[[VAL_3]], %[[VAL_4]]) : (tensor<7xi32>, tensor<3xi64>) -> tensor<1x7x1xi32>
+// CHECK: %[[VAL_6:.*]] = "tf.Const"() {value = dense<[5, 7, 9]> : tensor<3xi64>} : () -> tensor<3xi64>
+// CHECK: %[[VAL_7:.*]] = "tf.BroadcastTo"(%[[VAL_5]], %[[VAL_6]]) : (tensor<1x7x1xi32>, tensor<3xi64>) -> tensor<5x7x9xi32>
+// CHECK: return %[[VAL_7]] : tensor<5x7x9xi32>
+// CHECK: }
+func @convert_iota_3d() -> tensor<5x7x9xi32> {
+  %0 = "mhlo.iota"() { iota_dimension = 1 : i64 } : () -> tensor<5x7x9xi32>
+  return %0 : tensor<5x7x9xi32>
+}
-
-
-// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-
-// CHECK-LABEL: func @biasAdd_NHWC(
-// CHECK-SAME: [[VAL_0:%.*]]: tensor<1x32x10x32xi32>, [[VAL_1:%.*]]: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
-// CHECK: [[VAL_2:%.*]] = "tf.AddV2"([[VAL_0]], [[VAL_1]]) : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
-// CHECK: return [[VAL_2]] : tensor<1x32x10x32xi32>
-// CHECK: }
-
-// CHECK-LABEL: func @biasAdd_NCHW(
-// CHECK-SAME: [[VAL_3:%.*]]: tensor<1x32x10x32xi32>, [[VAL_4:%.*]]: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
-// CHECK: [[VAL_5:%.*]] = "tf.AddV2"([[VAL_3]], [[VAL_4]]) : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
-// CHECK: return [[VAL_5]] : tensor<1x32x10x32xi32>
-// CHECK: }
-
-// CHECK-LABEL: func @biasAdd_dynamic(
-// CHECK-SAME: [[VAL_6:%.*]]: tensor, [[VAL_7:%.*]]: tensor) -> tensor {
-// CHECK: [[VAL_8:%.*]] = "tf.AddV2"([[VAL_6]], [[VAL_7]]) : (tensor, tensor) -> tensor
-// CHECK: return [[VAL_8]] : tensor
-// CHECK: }
-
-// CHECK-LABEL: func @add(
-// CHECK-SAME: [[VAL_9:%.*]]: tensor<2xi32>) -> tensor<2xi32> {
-// CHECK: [[VAL_10:%.*]] = "tf.AddV2"([[VAL_9]], [[VAL_9]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
-// CHECK: [[VAL_11:%.*]] = "tf.AddV2"([[VAL_10]], [[VAL_9]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
-// CHECK: return [[VAL_11]] : tensor<2xi32>
-// CHECK: }
-
-// CHECK-LABEL: func @broadcast_add(
-// CHECK-SAME: [[VAL_12:%.*]]: tensor<1xi32>, [[VAL_13:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> {
-// CHECK: [[VAL_14:%.*]] = "tf.AddV2"([[VAL_12]],
[[VAL_13]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return [[VAL_14]] : tensor<1x2xi32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_multi_dim_add( -// CHECK-SAME: [[VAL_15:%.*]]: tensor<4x1x1xi32>, [[VAL_16:%.*]]: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { -// CHECK: [[VAL_17:%.*]] = "tf.AddV2"([[VAL_15]], [[VAL_16]]) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> -// CHECK: return [[VAL_17]] : tensor<4x4x4x4xi32> -// CHECK: } - -// CHECK-LABEL: func @div( -// CHECK-SAME: [[VAL_18:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_19:%.*]] = "tf.Div"([[VAL_18]], [[VAL_18]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_19]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_div( -// CHECK-SAME: [[VAL_20:%.*]]: tensor<1xi32>, [[VAL_21:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_22:%.*]] = "tf.Div"([[VAL_20]], [[VAL_21]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return [[VAL_22]] : tensor<1x2xi32> -// CHECK: } - -// CHECK-LABEL: func @shift_left( -// CHECK-SAME: [[VAL_23:%.*]]: tensor<4xi32>, [[VAL_24:%.*]]: tensor<4xi32>) -> tensor<4xi32> { -// CHECK: [[VAL_25:%.*]] = "tf.LeftShift"([[VAL_23]], [[VAL_24]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> -// CHECK: return [[VAL_25]] : tensor<4xi32> -// CHECK: } - -// CHECK-LABEL: func @div_dynamic( -// CHECK-SAME: [[VAL_26:%.*]]: tensor, [[VAL_27:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_28:%.*]] = "tf.Div"([[VAL_26]], [[VAL_27]]) : (tensor, tensor) -> tensor -// CHECK: return [[VAL_28]] : tensor -// CHECK: } - -// CHECK-LABEL: func @maximum( -// CHECK-SAME: [[VAL_29:%.*]]: tensor<4xf32>, [[VAL_30:%.*]]: tensor<4xf32>) -> tensor<4xf32> { -// CHECK: [[VAL_31:%.*]] = "tf.Maximum"([[VAL_29]], [[VAL_30]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> -// CHECK: return [[VAL_31]] : tensor<4xf32> -// CHECK: } - -// CHECK-LABEL: func @minimum( -// CHECK-SAME: [[VAL_32:%.*]]: tensor<4xf32>, [[VAL_33:%.*]]: tensor<4xf32>) -> tensor<4xf32> { -// CHECK: [[VAL_34:%.*]] = "tf.Minimum"([[VAL_32]], [[VAL_33]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> -// CHECK: return [[VAL_34]] : tensor<4xf32> -// CHECK: } - -// CHECK-LABEL: func @mul( -// CHECK-SAME: [[VAL_35:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_36:%.*]] = "tf.Mul"([[VAL_35]], [[VAL_35]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_36]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_mul( -// CHECK-SAME: [[VAL_37:%.*]]: tensor<1xi32>, [[VAL_38:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_39:%.*]] = "tf.Mul"([[VAL_37]], [[VAL_38]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return [[VAL_39]] : tensor<1x2xi32> -// CHECK: } - -// CHECK-LABEL: func @real_div( -// CHECK-SAME: [[VAL_40:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_41:%.*]] = "tf.Div"([[VAL_40]], [[VAL_40]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_41]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_real_div( -// CHECK-SAME: [[VAL_42:%.*]]: tensor<1xi32>, [[VAL_43:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_44:%.*]] = "tf.Div"([[VAL_42]], [[VAL_43]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return [[VAL_44]] : tensor<1x2xi32> -// CHECK: } - -// CHECK-LABEL: func @sub( -// CHECK-SAME: [[VAL_45:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_46:%.*]] = "tf.Sub"([[VAL_45]], 
[[VAL_45]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_46]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_sub( -// CHECK-SAME: [[VAL_47:%.*]]: tensor<1xi32>, [[VAL_48:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_49:%.*]] = "tf.Sub"([[VAL_47]], [[VAL_48]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return [[VAL_49]] : tensor<1x2xi32> -// CHECK: } - -// CHECK-LABEL: func @shift_right( -// CHECK-SAME: [[VAL_50:%.*]]: tensor<4xi32>, [[VAL_51:%.*]]: tensor<4xi32>) -> tensor<4xi32> { -// CHECK: [[VAL_52:%.*]] = "tf.RightShift"([[VAL_50]], [[VAL_51]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> -// CHECK: return [[VAL_52]] : tensor<4xi32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_shift_right( -// CHECK-SAME: [[VAL_53:%.*]]: tensor<4xi32>, [[VAL_54:%.*]]: tensor<2x4xi32>) -> tensor<2x4xi32> { -// CHECK: [[VAL_55:%.*]] = "tf.RightShift"([[VAL_53]], [[VAL_54]]) : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> -// CHECK: return [[VAL_55]] : tensor<2x4xi32> -// CHECK: } - -// CHECK-LABEL: func @and( -// CHECK-SAME: [[VAL_56:%.*]]: tensor<2xi1>) -> tensor<2xi1> { -// CHECK: [[VAL_57:%.*]] = "tf.LogicalAnd"([[VAL_56]], [[VAL_56]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> -// CHECK: return [[VAL_57]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @and_broadcast( -// CHECK-SAME: [[VAL_58:%.*]]: tensor<1xi1>, [[VAL_59:%.*]]: tensor<1x2xi1>) -> tensor<1x2xi1> { -// CHECK: [[VAL_60:%.*]] = "tf.LogicalAnd"([[VAL_58]], [[VAL_59]]) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> -// CHECK: return [[VAL_60]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @and_dynamic( -// CHECK-SAME: [[VAL_61:%.*]]: tensor, [[VAL_62:%.*]]: tensor<1xi1>) -> tensor { -// CHECK: [[VAL_63:%.*]] = "tf.LogicalAnd"([[VAL_61]], [[VAL_62]]) : (tensor, tensor<1xi1>) -> tensor -// CHECK: return [[VAL_63]] : tensor -// CHECK: } - -// CHECK-LABEL: func @or( -// CHECK-SAME: [[VAL_64:%.*]]: tensor<2xi1>) -> tensor<2xi1> { -// CHECK: [[VAL_65:%.*]] = "tf.LogicalOr"([[VAL_64]], [[VAL_64]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> -// CHECK: return [[VAL_65]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @or_broadcast( -// CHECK-SAME: [[VAL_66:%.*]]: tensor<1xi1>, [[VAL_67:%.*]]: tensor<1x2xi1>) -> tensor<1x2xi1> { -// CHECK: [[VAL_68:%.*]] = "tf.LogicalOr"([[VAL_66]], [[VAL_67]]) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> -// CHECK: return [[VAL_68]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @or_dynamic( -// CHECK-SAME: [[VAL_69:%.*]]: tensor, [[VAL_70:%.*]]: tensor<1xi1>) -> tensor { -// CHECK: [[VAL_71:%.*]] = "tf.LogicalOr"([[VAL_69]], [[VAL_70]]) : (tensor, tensor<1xi1>) -> tensor -// CHECK: return [[VAL_71]] : tensor -// CHECK: } - -// CHECK-LABEL: func @bitwise_or( -// CHECK-SAME: [[VAL_72:%.*]]: tensor<4xi32>, [[VAL_73:%.*]]: tensor<4xi32>) -> tensor<4xi32> { -// CHECK: [[VAL_74:%.*]] = "tf.BitwiseOr"([[VAL_72]], [[VAL_73]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> -// CHECK: return [[VAL_74]] : tensor<4xi32> -// CHECK: } - -// CHECK-LABEL: func @bitwise_or_broadcast( -// CHECK-SAME: [[VAL_75:%.*]]: tensor<1xi8>, [[VAL_76:%.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> { -// CHECK: [[VAL_77:%.*]] = "tf.BitwiseOr"([[VAL_75]], [[VAL_76]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> -// CHECK: return [[VAL_77]] : tensor<1x4xi8> -// CHECK: } - -// CHECK-LABEL: func @bitwise_or_dynamic( -// CHECK-SAME: [[VAL_78:%.*]]: tensor, [[VAL_79:%.*]]: tensor<1xi32>) -> tensor { -// CHECK: 
[[VAL_80:%.*]] = "tf.BitwiseOr"([[VAL_78]], [[VAL_79]]) : (tensor, tensor<1xi32>) -> tensor -// CHECK: return [[VAL_80]] : tensor -// CHECK: } - -// CHECK-LABEL: func @bitwise_and( -// CHECK-SAME: [[VAL_81:%.*]]: tensor<4xi32>, [[VAL_82:%.*]]: tensor<4xi32>) -> tensor<4xi32> { -// CHECK: [[VAL_83:%.*]] = "tf.BitwiseAnd"([[VAL_81]], [[VAL_82]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> -// CHECK: return [[VAL_83]] : tensor<4xi32> -// CHECK: } - -// CHECK-LABEL: func @bitwise_and_broadcast( -// CHECK-SAME: [[VAL_84:%.*]]: tensor<1xi8>, [[VAL_85:%.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> { -// CHECK: [[VAL_86:%.*]] = "tf.BitwiseAnd"([[VAL_84]], [[VAL_85]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> -// CHECK: return [[VAL_86]] : tensor<1x4xi8> -// CHECK: } - -// CHECK-LABEL: func @bitwise_and_dynamic( -// CHECK-SAME: [[VAL_87:%.*]]: tensor, [[VAL_88:%.*]]: tensor<1xi32>) -> tensor { -// CHECK: [[VAL_89:%.*]] = "tf.BitwiseAnd"([[VAL_87]], [[VAL_88]]) : (tensor, tensor<1xi32>) -> tensor -// CHECK: return [[VAL_89]] : tensor -// CHECK: } - -// CHECK-LABEL: func @pow( -// CHECK-SAME: [[VAL_90:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_91:%.*]] = "tf.Pow"([[VAL_90]], [[VAL_90]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_91]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @pow_dynamic( -// CHECK-SAME: [[VAL_92:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_93:%.*]] = "tf.Pow"([[VAL_92]], [[VAL_92]]) : (tensor, tensor) -> tensor -// CHECK: return [[VAL_93]] : tensor -// CHECK: } - -// CHECK-LABEL: func @floordiv_broadcast_i32( -// CHECK-SAME: [[VAL_94:%.*]]: tensor<2x3xi32>, [[VAL_95:%.*]]: tensor<3xi32>) -> tensor<2x3xi32> { -// CHECK: [[VAL_96:%.*]] = "tf.Const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32> -// CHECK: [[VAL_97:%.*]] = "tf.Less"([[VAL_94]], [[VAL_96]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> -// CHECK: [[VAL_98:%.*]] = "tf.Const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK: [[VAL_99:%.*]] = "tf.Less"([[VAL_95]], [[VAL_98]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> -// CHECK: [[VAL_100:%.*]] = "tf.Equal"([[VAL_97]], [[VAL_99]]) {incompatible_shape_error = true} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> -// CHECK: [[VAL_101:%.*]] = "tf.Div"([[VAL_94]], [[VAL_95]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_102:%.*]] = "tf.Abs"([[VAL_94]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_103:%.*]] = "tf.Abs"([[VAL_95]]) : (tensor<3xi32>) -> tensor<3xi32> -// CHECK: [[VAL_104:%.*]] = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK: [[VAL_105:%.*]] = "tf.Sub"([[VAL_103]], [[VAL_104]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> -// CHECK: [[VAL_106:%.*]] = "tf.AddV2"([[VAL_102]], [[VAL_105]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_107:%.*]] = "tf.Neg"([[VAL_106]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_108:%.*]] = "tf.Abs"([[VAL_95]]) : (tensor<3xi32>) -> tensor<3xi32> -// CHECK: [[VAL_109:%.*]] = "tf.Div"([[VAL_107]], [[VAL_108]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_110:%.*]] = "tf.Select"([[VAL_100]], [[VAL_101]], [[VAL_109]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: return [[VAL_110]] : tensor<2x3xi32> -// CHECK: } - -// CHECK-LABEL: func @floordiv_reverse_broadcast_i32( -// CHECK-SAME: [[VAL_111:%.*]]: tensor<3xi32>, [[VAL_112:%.*]]: tensor<2x3xi32>) 
-> tensor<2x3xi32> { -// CHECK: [[VAL_113:%.*]] = "tf.Const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK: [[VAL_114:%.*]] = "tf.Less"([[VAL_111]], [[VAL_113]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> -// CHECK: [[VAL_115:%.*]] = "tf.Const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32> -// CHECK: [[VAL_116:%.*]] = "tf.Less"([[VAL_112]], [[VAL_115]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> -// CHECK: [[VAL_117:%.*]] = "tf.Equal"([[VAL_114]], [[VAL_116]]) {incompatible_shape_error = true} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> -// CHECK: [[VAL_118:%.*]] = "tf.Div"([[VAL_111]], [[VAL_112]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_119:%.*]] = "tf.Abs"([[VAL_111]]) : (tensor<3xi32>) -> tensor<3xi32> -// CHECK: [[VAL_120:%.*]] = "tf.Abs"([[VAL_112]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_121:%.*]] = "tf.Const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32> -// CHECK: [[VAL_122:%.*]] = "tf.Sub"([[VAL_120]], [[VAL_121]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_123:%.*]] = "tf.AddV2"([[VAL_119]], [[VAL_122]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_124:%.*]] = "tf.Neg"([[VAL_123]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_125:%.*]] = "tf.Abs"([[VAL_112]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_126:%.*]] = "tf.Div"([[VAL_124]], [[VAL_125]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_127:%.*]] = "tf.Select"([[VAL_117]], [[VAL_118]], [[VAL_126]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: return [[VAL_127]] : tensor<2x3xi32> -// CHECK: } - -// CHECK-LABEL: func @floordiv_f32( -// CHECK-SAME: [[VAL_128:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_129:%.*]] = "tf.Div"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: [[VAL_130:%.*]] = "tf.Div"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: [[VAL_131:%.*]] = "tf.FloorDiv"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_131]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @floordiv_f16_broadcast( -// CHECK-SAME: [[VAL_132:%.*]]: tensor<2x3xf16>, [[VAL_133:%.*]]: tensor<3xf16>) -> tensor<2x3xf16> { -// CHECK: [[VAL_134:%.*]] = "tf.Div"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> -// CHECK: [[VAL_135:%.*]] = "tf.Div"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> -// CHECK: [[VAL_136:%.*]] = "tf.FloorDiv"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> -// CHECK: return [[VAL_136]] : tensor<2x3xf16> -// CHECK: } - -// CHECK-LABEL: func @equal( -// CHECK-SAME: [[VAL_137:%.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: [[VAL_138:%.*]] = "tf.Equal"([[VAL_137]], [[VAL_137]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return [[VAL_138]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @equal_dynamic( -// CHECK-SAME: [[VAL_139:%.*]]: tensor, [[VAL_140:%.*]]: tensor<1xi32>) -> tensor { -// CHECK: [[VAL_141:%.*]] = "tf.Equal"([[VAL_139]], [[VAL_140]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor -// CHECK: return [[VAL_141]] : tensor -// CHECK: } - -// CHECK-LABEL: func @equal_broadcast( -// CHECK-SAME: [[VAL_142:%.*]]: tensor<1xi32>, 
[[VAL_143:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_144:%.*]] = "tf.Equal"([[VAL_142]], [[VAL_143]]) {incompatible_shape_error = true} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_144]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error( -// CHECK-SAME: [[VAL_145:%.*]]: tensor<2xi32>, [[VAL_146:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_147:%.*]] = "tf.Equal"([[VAL_145]], [[VAL_146]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_147]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @equal_incompatible_shape_broadcastable( -// CHECK-SAME: [[VAL_148:%.*]]: tensor, [[VAL_149:%.*]]: tensor<1xi32>) -> tensor { -// CHECK: [[VAL_150:%.*]] = "tf.Equal"([[VAL_148]], [[VAL_149]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor -// CHECK: return [[VAL_150]] : tensor -// CHECK: } - -// CHECK-LABEL: func @notequal( -// CHECK-SAME: [[VAL_151:%.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: [[VAL_152:%.*]] = "tf.NotEqual"([[VAL_151]], [[VAL_151]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return [[VAL_152]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @notequal_broadcast( -// CHECK-SAME: [[VAL_153:%.*]]: tensor<1xi32>, [[VAL_154:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_155:%.*]] = "tf.NotEqual"([[VAL_153]], [[VAL_154]]) {incompatible_shape_error = true} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_155]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @notequal_broadcast_no_incompatible_shapes_error( -// CHECK-SAME: [[VAL_156:%.*]]: tensor<2xi32>, [[VAL_157:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_158:%.*]] = "tf.NotEqual"([[VAL_156]], [[VAL_157]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_158]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @notequal_incompatible_shape_broadcastable( -// CHECK-SAME: [[VAL_159:%.*]]: tensor, [[VAL_160:%.*]]: tensor<1xi32>) -> tensor { -// CHECK: [[VAL_161:%.*]] = "tf.NotEqual"([[VAL_159]], [[VAL_160]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor -// CHECK: return [[VAL_161]] : tensor -// CHECK: } - -// CHECK-LABEL: func @greater( -// CHECK-SAME: [[VAL_162:%.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: [[VAL_163:%.*]] = "tf.Greater"([[VAL_162]], [[VAL_162]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return [[VAL_163]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @broadcast_greater( -// CHECK-SAME: [[VAL_164:%.*]]: tensor<1xi32>, [[VAL_165:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_166:%.*]] = "tf.Greater"([[VAL_164]], [[VAL_165]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_166]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @greater_equal( -// CHECK-SAME: [[VAL_167:%.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: [[VAL_168:%.*]] = "tf.GreaterEqual"([[VAL_167]], [[VAL_167]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return [[VAL_168]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @broadcast_greater_equal( -// CHECK-SAME: [[VAL_169:%.*]]: tensor<1xi32>, [[VAL_170:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_171:%.*]] = "tf.GreaterEqual"([[VAL_169]], [[VAL_170]]) : (tensor<1xi32>, tensor<1x2xi32>) -> 
tensor<1x2xi1> -// CHECK: return [[VAL_171]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @less( -// CHECK-SAME: [[VAL_172:%.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: [[VAL_173:%.*]] = "tf.Less"([[VAL_172]], [[VAL_172]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return [[VAL_173]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @broadcast_less( -// CHECK-SAME: [[VAL_174:%.*]]: tensor<1xi32>, [[VAL_175:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_176:%.*]] = "tf.Less"([[VAL_174]], [[VAL_175]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_176]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @less_equal( -// CHECK-SAME: [[VAL_177:%.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: [[VAL_178:%.*]] = "tf.LessEqual"([[VAL_177]], [[VAL_177]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return [[VAL_178]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @broadcast_less_equal( -// CHECK-SAME: [[VAL_179:%.*]]: tensor<1xi32>, [[VAL_180:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { -// CHECK: [[VAL_181:%.*]] = "tf.LessEqual"([[VAL_179]], [[VAL_180]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> -// CHECK: return [[VAL_181]] : tensor<1x2xi1> -// CHECK: } - -// CHECK-LABEL: func @concat_v2( -// CHECK-SAME: [[VAL_182:%.*]]: tensor<3x3xf32>, [[VAL_183:%.*]]: tensor<3x3xf32>) -> tensor<6x3xf32> { -// CHECK: [[VAL_184:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_185:%.*]] = "tf.ConcatV2"([[VAL_182]], [[VAL_183]], [[VAL_184]]) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> -// CHECK: return [[VAL_185]] : tensor<6x3xf32> -// CHECK: } - -// CHECK-LABEL: func @concat_v2_1d_axis( -// CHECK-SAME: [[VAL_186:%.*]]: tensor<3x3xf32>, [[VAL_187:%.*]]: tensor<3x3xf32>) -> tensor<3x6xf32> { -// CHECK: [[VAL_188:%.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor -// CHECK: [[VAL_189:%.*]] = "tf.ConcatV2"([[VAL_186]], [[VAL_187]], [[VAL_188]]) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<3x6xf32> -// CHECK: return [[VAL_189]] : tensor<3x6xf32> -// CHECK: } - -// CHECK-LABEL: func @const() -> tensor<2xi32> { -// CHECK: [[VAL_190:%.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32> -// CHECK: return [[VAL_190]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @relu( -// CHECK-SAME: [[VAL_192:%.*]]: tensor<1xi32>) -> tensor<1xi32> { -// CHECK: [[VAL_193:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_194:%.*]] = "tf.Maximum"([[VAL_193]], [[VAL_192]]) : (tensor, tensor<1xi32>) -> tensor<1xi32> -// CHECK: return [[VAL_194]] : tensor<1xi32> -// CHECK: } - -// CHECK-LABEL: func @relu_unranked( -// CHECK-SAME: [[VAL_195:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_196:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_197:%.*]] = "tf.Maximum"([[VAL_196]], [[VAL_195]]) : (tensor, tensor) -> tensor -// CHECK: return [[VAL_197]] : tensor -// CHECK: } - -// CHECK-LABEL: func @relu6( -// CHECK-SAME: [[VAL_198:%.*]]: tensor<1xi32>) -> tensor<1xi32> { -// CHECK: [[VAL_199:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_200:%.*]] = "tf.Const"() {value = dense<6> : tensor} : () -> tensor -// CHECK: [[VAL_201:%.*]] = "tf.Minimum"([[VAL_198]], [[VAL_200]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> -// CHECK: [[VAL_202:%.*]] = "tf.Maximum"([[VAL_201]], [[VAL_199]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> -// CHECK: return 
[[VAL_202]] : tensor<1xi32> -// CHECK: } - -// CHECK-LABEL: func @relu6_unranked( -// CHECK-SAME: [[VAL_203:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_204:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_205:%.*]] = "tf.Const"() {value = dense<6> : tensor} : () -> tensor -// CHECK: [[VAL_206:%.*]] = "tf.Minimum"([[VAL_203]], [[VAL_205]]) : (tensor, tensor) -> tensor -// CHECK: [[VAL_207:%.*]] = "tf.Maximum"([[VAL_206]], [[VAL_204]]) : (tensor, tensor) -> tensor -// CHECK: return [[VAL_207]] : tensor -// CHECK: } - -// CHECK-LABEL: func @relu_grad( -// CHECK-SAME: [[VAL_208:%.*]]: tensor<4x8xf32>, [[VAL_209:%.*]]: tensor) -> tensor<4x8xf32> { -// CHECK: [[VAL_210:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} : () -> tensor -// CHECK: [[VAL_211:%.*]] = "tf.Greater"([[VAL_209]], [[VAL_210]]) : (tensor, tensor) -> tensor -// CHECK: [[VAL_212:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x8xf32>} : () -> tensor<4x8xf32> -// CHECK: [[VAL_213:%.*]] = "tf.Select"([[VAL_211]], [[VAL_208]], [[VAL_212]]) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> -// CHECK: return [[VAL_213]] : tensor<4x8xf32> -// CHECK: } - -// CHECK-LABEL: func @select( -// CHECK-SAME: [[VAL_214:%.*]]: tensor<2xi1>, [[VAL_215:%.*]]: tensor<2xi32>, [[VAL_216:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_217:%.*]] = "tf.Select"([[VAL_214]], [[VAL_215]], [[VAL_216]]) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_217]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @select_float( -// CHECK-SAME: [[VAL_218:%.*]]: tensor<2xi1>, [[VAL_219:%.*]]: tensor<2xf32>, [[VAL_220:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_221:%.*]] = "tf.Select"([[VAL_218]], [[VAL_219]], [[VAL_220]]) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_221]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @select_multidimensional( -// CHECK-SAME: [[VAL_222:%.*]]: tensor<3x2xi1>, [[VAL_223:%.*]]: tensor<3x2xi32>, [[VAL_224:%.*]]: tensor<3x2xi32>) -> tensor<3x2xi32> { -// CHECK: [[VAL_225:%.*]] = "tf.Select"([[VAL_222]], [[VAL_223]], [[VAL_224]]) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> -// CHECK: return [[VAL_225]] : tensor<3x2xi32> -// CHECK: } - -// CHECK-LABEL: func @selectv2( -// CHECK-SAME: [[VAL_226:%.*]]: tensor<2xi1>, [[VAL_227:%.*]]: tensor<2xi32>, [[VAL_228:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_229:%.*]] = "tf.Select"([[VAL_226]], [[VAL_227]], [[VAL_228]]) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_229]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @selectv2_pred_scalar( -// CHECK-SAME: [[VAL_230:%.*]]: tensor, [[VAL_231:%.*]]: tensor<2xi32>, [[VAL_232:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_233:%.*]] = "tf.Select"([[VAL_230]], [[VAL_231]], [[VAL_232]]) : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_233]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @transpose_2d( -// CHECK-SAME: [[VAL_234:%.*]]: tensor<2x3xf32>) -> tensor<3x2xf32> { -// CHECK: [[VAL_235:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_236:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_237:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_238:%.*]] = "tf.Transpose"([[VAL_234]], 
[[VAL_237]]) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> -// CHECK: return [[VAL_238]] : tensor<3x2xf32> -// CHECK: } - -// CHECK-LABEL: func @transpose_3d_int32( -// CHECK-SAME: [[VAL_239:%.*]]: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { -// CHECK: [[VAL_240:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK: [[VAL_241:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> -// CHECK: [[VAL_242:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> -// CHECK: [[VAL_243:%.*]] = "tf.Transpose"([[VAL_239]], [[VAL_242]]) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> -// CHECK: return [[VAL_243]] : tensor<3x2x1xf32> -// CHECK: } - -// CHECK-LABEL: func @transpose_3d( -// CHECK-SAME: [[VAL_244:%.*]]: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { -// CHECK: [[VAL_245:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> -// CHECK: [[VAL_246:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> -// CHECK: [[VAL_247:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> -// CHECK: [[VAL_248:%.*]] = "tf.Transpose"([[VAL_244]], [[VAL_247]]) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> -// CHECK: return [[VAL_248]] : tensor<3x2x1xf32> -// CHECK: } - -// CHECK-LABEL: func @transpose_dynamic_2d( -// CHECK-SAME: [[VAL_249:%.*]]: tensor) -> tensor<4x?xf32> { -// CHECK: [[VAL_250:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_251:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_252:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_253:%.*]] = "tf.Transpose"([[VAL_249]], [[VAL_252]]) : (tensor, tensor<2xi64>) -> tensor<4x?xf32> -// CHECK: return [[VAL_253]] : tensor<4x?xf32> -// CHECK: } - -// CHECK-LABEL: func @transpose_unranked_2d( -// CHECK-SAME: [[VAL_254:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_255:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_256:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_257:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_258:%.*]] = "tf.Transpose"([[VAL_254]], [[VAL_257]]) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32> -// CHECK: return [[VAL_258]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @abs( -// CHECK-SAME: [[VAL_259:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_260:%.*]] = "tf.Abs"([[VAL_259]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_260]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @abs_dynamic( -// CHECK-SAME: [[VAL_261:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_262:%.*]] = "tf.Abs"([[VAL_261]]) : (tensor) -> tensor -// CHECK: return [[VAL_262]] : tensor -// CHECK: } - -// CHECK-LABEL: func @abs_unranked( -// CHECK-SAME: [[VAL_263:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_264:%.*]] = "tf.Abs"([[VAL_263]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_264]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @ceil( -// CHECK-SAME: [[VAL_265:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_266:%.*]] = "tf.Ceil"([[VAL_265]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_266]] : tensor<2xf32> -// CHECK: } - -// 
CHECK-LABEL: func @ceil_dynamic( -// CHECK-SAME: [[VAL_267:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_268:%.*]] = "tf.Ceil"([[VAL_267]]) : (tensor) -> tensor -// CHECK: return [[VAL_268]] : tensor -// CHECK: } - -// CHECK-LABEL: func @ceil_unranked( -// CHECK-SAME: [[VAL_269:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_270:%.*]] = "tf.Ceil"([[VAL_269]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_270]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @complex_abs( -// CHECK-SAME: [[VAL_271:%.*]]: tensor<2xcomplex>) -> tensor<2xf32> { -// CHECK: [[VAL_272:%.*]] = "tf.ComplexAbs"([[VAL_271]]) : (tensor<2xcomplex>) -> tensor<2xf32> -// CHECK: return [[VAL_272]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @cos( -// CHECK-SAME: [[VAL_273:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_274:%.*]] = "tf.Cos"([[VAL_273]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_274]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @cos_dynamic( -// CHECK-SAME: [[VAL_275:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_276:%.*]] = "tf.Cos"([[VAL_275]]) : (tensor) -> tensor -// CHECK: return [[VAL_276]] : tensor -// CHECK: } - -// CHECK-LABEL: func @cos_unranked( -// CHECK-SAME: [[VAL_277:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_278:%.*]] = "tf.Cos"([[VAL_277]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_278]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @exp( -// CHECK-SAME: [[VAL_279:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_280:%.*]] = "tf.Exp"([[VAL_279]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_280]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @exp_dynamic( -// CHECK-SAME: [[VAL_281:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_282:%.*]] = "tf.Exp"([[VAL_281]]) : (tensor) -> tensor -// CHECK: return [[VAL_282]] : tensor -// CHECK: } - -// CHECK-LABEL: func @exp_unranked( -// CHECK-SAME: [[VAL_283:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_284:%.*]] = "tf.Exp"([[VAL_283]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_284]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @floor( -// CHECK-SAME: [[VAL_285:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_286:%.*]] = "tf.Floor"([[VAL_285]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_286]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @floor_dynamic( -// CHECK-SAME: [[VAL_287:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_288:%.*]] = "tf.Floor"([[VAL_287]]) : (tensor) -> tensor -// CHECK: return [[VAL_288]] : tensor -// CHECK: } - -// CHECK-LABEL: func @floor_unranked( -// CHECK-SAME: [[VAL_289:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_290:%.*]] = "tf.Floor"([[VAL_289]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_290]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @is_finite( -// CHECK-SAME: [[VAL_291:%.*]]: tensor<2xf32>) -> tensor<2xi1> { -// CHECK: [[VAL_292:%.*]] = "tf.IsFinite"([[VAL_291]]) : (tensor<2xf32>) -> tensor<2xi1> -// CHECK: return [[VAL_292]] : tensor<2xi1> -// CHECK: } - -// CHECK-LABEL: func @is_finite_dynamic( -// CHECK-SAME: [[VAL_293:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_294:%.*]] = "tf.IsFinite"([[VAL_293]]) : (tensor) -> tensor -// CHECK: return [[VAL_294]] : tensor -// CHECK: } - -// CHECK-LABEL: func @is_finite_unranked( -// CHECK-SAME: [[VAL_295:%.*]]: tensor<*xf32>) -> tensor<*xi1> { -// CHECK: [[VAL_296:%.*]] = "tf.IsFinite"([[VAL_295]]) : (tensor<*xf32>) -> tensor<*xi1> -// CHECK: 
return [[VAL_296]] : tensor<*xi1> -// CHECK: } - -// CHECK-LABEL: func @log( -// CHECK-SAME: [[VAL_297:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_298:%.*]] = "tf.Log"([[VAL_297]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_298]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @log_dynamic( -// CHECK-SAME: [[VAL_299:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_300:%.*]] = "tf.Log"([[VAL_299]]) : (tensor) -> tensor -// CHECK: return [[VAL_300]] : tensor -// CHECK: } - -// CHECK-LABEL: func @log_unranked( -// CHECK-SAME: [[VAL_301:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_302:%.*]] = "tf.Log"([[VAL_301]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_302]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @log1p( -// CHECK-SAME: [[VAL_303:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_304:%.*]] = "tf.Log1p"([[VAL_303]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_304]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @log1p_dynamic( -// CHECK-SAME: [[VAL_305:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_306:%.*]] = "tf.Log1p"([[VAL_305]]) : (tensor) -> tensor -// CHECK: return [[VAL_306]] : tensor -// CHECK: } - -// CHECK-LABEL: func @log1p_unranked( -// CHECK-SAME: [[VAL_307:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_308:%.*]] = "tf.Log1p"([[VAL_307]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_308]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @neg( -// CHECK-SAME: [[VAL_309:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_310:%.*]] = "tf.Neg"([[VAL_309]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_310]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @neg_dynamic( -// CHECK-SAME: [[VAL_311:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_312:%.*]] = "tf.Neg"([[VAL_311]]) : (tensor) -> tensor -// CHECK: return [[VAL_312]] : tensor -// CHECK: } - -// CHECK-LABEL: func @neg_unranked( -// CHECK-SAME: [[VAL_313:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_314:%.*]] = "tf.Neg"([[VAL_313]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_314]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @sigmoid( -// CHECK-SAME: [[VAL_315:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_316:%.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor} : () -> tensor -// CHECK: [[VAL_317:%.*]] = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64> -// CHECK: [[VAL_318:%.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32> -// CHECK: [[VAL_319:%.*]] = "tf.Mul"([[VAL_315]], [[VAL_318]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: [[VAL_320:%.*]] = "tf.Tanh"([[VAL_319]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: [[VAL_321:%.*]] = "tf.Mul"([[VAL_320]], [[VAL_318]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: [[VAL_322:%.*]] = "tf.AddV2"([[VAL_321]], [[VAL_318]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_322]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @sin( -// CHECK-SAME: [[VAL_323:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_324:%.*]] = "tf.Sin"([[VAL_323]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_324]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @sin_dynamic( -// CHECK-SAME: [[VAL_325:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_326:%.*]] = "tf.Sin"([[VAL_325]]) : (tensor) -> tensor -// CHECK: return [[VAL_326]] : tensor -// CHECK: } - -// CHECK-LABEL: 
func @sin_unranked( -// CHECK-SAME: [[VAL_327:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_328:%.*]] = "tf.Sin"([[VAL_327]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_328]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @rsqrt( -// CHECK-SAME: [[VAL_329:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_330:%.*]] = "tf.Rsqrt"([[VAL_329]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_330]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @rsqrt_dynamic( -// CHECK-SAME: [[VAL_331:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_332:%.*]] = "tf.Rsqrt"([[VAL_331]]) : (tensor) -> tensor -// CHECK: return [[VAL_332]] : tensor -// CHECK: } - -// CHECK-LABEL: func @rsqrt_unranked( -// CHECK-SAME: [[VAL_333:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_334:%.*]] = "tf.Rsqrt"([[VAL_333]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_334]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @sqrt( -// CHECK-SAME: [[VAL_335:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_336:%.*]] = "tf.Sqrt"([[VAL_335]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_336]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @sqrt_dynamic( -// CHECK-SAME: [[VAL_337:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_338:%.*]] = "tf.Sqrt"([[VAL_337]]) : (tensor) -> tensor -// CHECK: return [[VAL_338]] : tensor -// CHECK: } - -// CHECK-LABEL: func @sqrt_unranked( -// CHECK-SAME: [[VAL_339:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_340:%.*]] = "tf.Sqrt"([[VAL_339]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_340]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @tanh( -// CHECK-SAME: [[VAL_341:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_342:%.*]] = "tf.Tanh"([[VAL_341]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_342]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @tanh_dynamic( -// CHECK-SAME: [[VAL_343:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_344:%.*]] = "tf.Tanh"([[VAL_343]]) : (tensor) -> tensor -// CHECK: return [[VAL_344]] : tensor -// CHECK: } - -// CHECK-LABEL: func @tanh_unranked( -// CHECK-SAME: [[VAL_345:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_346:%.*]] = "tf.Tanh"([[VAL_345]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_346]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @bitcast( -// CHECK-SAME: [[VAL_347:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_348:%.*]] = "tf.Bitcast"([[VAL_347]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_348]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @bitcast_dynamic( -// CHECK-SAME: [[VAL_349:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_350:%.*]] = "tf.Bitcast"([[VAL_349]]) : (tensor) -> tensor -// CHECK: return [[VAL_350]] : tensor -// CHECK: } - -// CHECK-LABEL: func @bitcast_unranked( -// CHECK-SAME: [[VAL_351:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_352:%.*]] = "tf.Bitcast"([[VAL_351]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_352]] : tensor<*xf32> -// CHECK: } - -// CHECK-LABEL: func @bitcast_same_widths( -// CHECK-SAME: [[VAL_353:%.*]]: tensor<2xf32>) -> tensor<2xi32> { -// CHECK: [[VAL_354:%.*]] = "tf.Bitcast"([[VAL_353]]) : (tensor<2xf32>) -> tensor<2xi32> -// CHECK: return [[VAL_354]] : tensor<2xi32> -// CHECK: } - -// CHECK-LABEL: func @sign( -// CHECK-SAME: [[VAL_355:%.*]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { -// CHECK: [[VAL_356:%.*]] = "tf.NotEqual"([[VAL_355]], [[VAL_355]]) 
{incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> -// CHECK: [[VAL_357:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32> -// CHECK: [[VAL_358:%.*]] = "tf.NotEqual"([[VAL_355]], [[VAL_355]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> -// CHECK: [[VAL_359:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32> -// CHECK: [[VAL_360:%.*]] = "tf.Sign"([[VAL_355]]) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> -// CHECK: [[VAL_361:%.*]] = "tf.Select"([[VAL_358]], [[VAL_359]], [[VAL_360]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> -// CHECK: [[VAL_362:%.*]] = "tf.Select"([[VAL_356]], [[VAL_357]], [[VAL_361]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> -// CHECK: return [[VAL_362]] : tensor<1x2x3x4xf32> -// CHECK: } - -// CHECK-LABEL: func @size_rank_one_i32( -// CHECK-SAME: [[VAL_363:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_364:%.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor -// CHECK: return [[VAL_364]] : tensor -// CHECK: } - -// CHECK-LABEL: func @size_rank_one_i64( -// CHECK-SAME: [[VAL_365:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_366:%.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor -// CHECK: return [[VAL_366]] : tensor -// CHECK: } - -// CHECK-LABEL: func @complex( -// CHECK-SAME: [[VAL_367:%.*]]: tensor<3xf32>, [[VAL_368:%.*]]: tensor<3xf32>) -> tensor<3xcomplex> { -// CHECK: [[VAL_369:%.*]] = "tf.Complex"([[VAL_367]], [[VAL_368]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> -// CHECK: return [[VAL_369]] : tensor<3xcomplex> -// CHECK: } - -// CHECK-LABEL: func @convert_i32_f32( -// CHECK-SAME: [[VAL_370:%.*]]: tensor<2xi32>) -> tensor<2xf32> { -// CHECK: [[VAL_371:%.*]] = "tf.Cast"([[VAL_370]]) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32> -// CHECK: return [[VAL_371]] : tensor<2xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_slice( -// CHECK-SAME: [[VAL_372:%.*]]: tensor<1x4672xf32>) -> tensor<1x519xf32> { -// CHECK: [[VAL_373:%.*]] = "tf.Const"() {value = dense<[0, 4153]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_374:%.*]] = "tf.Const"() {value = dense<[1, 519]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: [[VAL_375:%.*]] = "tf.Slice"([[VAL_372]], [[VAL_373]], [[VAL_374]]) : (tensor<1x4672xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x519xf32> -// CHECK: return [[VAL_375]] : tensor<1x519xf32> -// CHECK: } - -// CHECK-LABEL: func @reshape( -// CHECK-SAME: [[VAL_372:%.*]]: tensor<4x6xf32>) -> tensor<2x2x6xf32> { -// CHECK: [[VAL_373:%.*]] = constant dense<[2, 2, 6]> : tensor<3xi64> -// CHECK: [[VAL_374:%.*]] = "tf.Reshape"([[VAL_372]], [[VAL_373]]) : (tensor<4x6xf32>, tensor<3xi64>) -> tensor<2x2x6xf32> -// CHECK: return [[VAL_374]] : tensor<2x2x6xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_dot_1d_2d( -// CHECK-SAME: [[VAL_376:%.*]]: tensor<256xf32>, [[VAL_377:%.*]]: tensor<256x1xf32>) -> tensor<1xf32> { -// CHECK: [[VAL_378:%.*]] = "tf.Reshape"([[VAL_376]], {{.*}}) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> -// CHECK: [[VAL_379:%.*]] = "tf.MatMul"([[VAL_378]], [[VAL_377]]) {transpose_a = false, transpose_b = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> -// CHECK: [[VAL_380:%.*]] = "tf.Reshape"([[VAL_379]], {{.*}}) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32> -// CHECK: 
return [[VAL_380]] : tensor<1xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_dot_2d_1d( -// CHECK-SAME: [[VAL_381:%.*]]: tensor<1x256xf32>, [[VAL_382:%.*]]: tensor<256xf32>) -> tensor<1xf32> { -// CHECK: [[VAL_383:%.*]] = "tf.Reshape"([[VAL_382]], {{.*}}) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> -// CHECK: [[VAL_384:%.*]] = "tf.MatMul"([[VAL_381]], [[VAL_383]]) {transpose_a = false, transpose_b = true} : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x1xf32> -// CHECK: [[VAL_385:%.*]] = "tf.Reshape"([[VAL_384]], {{.*}}) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32> -// CHECK: return [[VAL_385]] : tensor<1xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_dot_1d_1d( -// CHECK-SAME: [[VAL_386:%.*]]: tensor<256xf32>, [[VAL_387:%.*]]: tensor<256xf32>) -> tensor { -// CHECK-DAG: [[VAL_388:%.*]] = "tf.Reshape"([[VAL_386]], {{.*}}) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> -// CHECK-DAG: [[VAL_389:%.*]] = "tf.Reshape"([[VAL_387]], {{.*}}) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> -// CHECK: [[VAL_390:%.*]] = "tf.MatMul"([[VAL_388]], [[VAL_389]]) {transpose_a = false, transpose_b = true} : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x1xf32> -// CHECK: [[VAL_391:%.*]] = "tf.Reshape"([[VAL_390]], {{.*}}) : (tensor<1x1xf32>, tensor<0xi64>) -> tensor -// CHECK: return [[VAL_391]] : tensor -// CHECK: } - -// CHECK-LABEL: func @convert_dot_2d_2d( -// CHECK-SAME: [[VAL_392:%.*]]: tensor<1x256xf32>, [[VAL_393:%.*]]: tensor<256x1xf32>) -> tensor<1x1xf32> { -// CHECK: [[VAL_394:%.*]] = "tf.MatMul"([[VAL_392]], [[VAL_393]]) {transpose_a = false, transpose_b = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> -// CHECK: return [[VAL_394]] : tensor<1x1xf32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_in_dim_tf_style( -// CHECK-SAME: [[VAL_395:%.*]]: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> { -// CHECK: [[VAL_396:%.*]] = constant dense<[3, 8, 8, 16]> : tensor<4xi64> -// CHECK: [[VAL_397:%.*]] = "tf.BroadcastTo"([[VAL_395]], [[VAL_396]]) : (tensor<8x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32> -// CHECK: return [[VAL_397]] : tensor<3x8x8x16xf32> -// CHECK: } - -// CHECK-LABEL: func @broadcast_in_dim_general_case( -// CHECK-SAME: [[VAL_398:%.*]]: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> { -// CHECK: [[VAL_399:%.*]] = constant dense<[3, 1, 1, 16]> : tensor<4xi64> -// CHECK: [[VAL_400:%.*]] = "tf.Reshape"([[VAL_398]], [[VAL_399]]) : (tensor<3x1x16xf32>, tensor<4xi64>) -> tensor<3x1x1x16xf32> -// CHECK: [[VAL_401:%.*]] = constant dense<[3, 8, 8, 16]> : tensor<4xi64> -// CHECK: [[VAL_402:%.*]] = "tf.BroadcastTo"([[VAL_400]], [[VAL_401]]) : (tensor<3x1x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32> -// CHECK: return [[VAL_402]] : tensor<3x8x8x16xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_dot_general( -// CHECK-SAME: [[VAL_396:%.*]]: tensor<3x2x6x5x1xf32>, [[VAL_397:%.*]]: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> { -// CHECK: [[VAL_398:%.*]] = "tf.Transpose"([[VAL_396]], {{.*}}) : (tensor<3x2x6x5x1xf32>, tensor<5xi64>) -> tensor<3x5x1x2x6xf32> -// CHECK: [[VAL_399:%.*]] = "tf.Transpose"([[VAL_397]], {{.*}}) : (tensor<3x2x4x6xf32>, tensor<4xi64>) -> tensor<3x2x6x4xf32> -// CHECK: [[VAL_400:%.*]] = "tf.Reshape"([[VAL_398]], {{.*}}) : (tensor<3x5x1x2x6xf32>, tensor<3xi64>) -> tensor<3x5x12xf32> -// CHECK: [[VAL_401:%.*]] = "tf.Reshape"([[VAL_399]], {{.*}}) : (tensor<3x2x6x4xf32>, tensor<3xi64>) -> tensor<3x12x4xf32> -// CHECK: [[VAL_402:%.*]] = "tf.BatchMatMulV2"([[VAL_400]], [[VAL_401]]) {adj_x = false, adj_y = 
false} : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32> -// CHECK: [[VAL_403:%.*]] = "tf.Reshape"([[VAL_402]], {{.*}}) : (tensor<3x5x4xf32>, tensor<4xi64>) -> tensor<3x5x1x4xf32> -// CHECK: return [[VAL_403]] : tensor<3x5x1x4xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_conv2d( -// CHECK-SAME: [[VAL_404:%.*]]: tensor<1x8x8x207xf32>, [[VAL_405:%.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { -// CHECK: [[VAL_406:%.*]] = "tf.Conv2D"([[VAL_404]], [[VAL_405]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> -// CHECK: return [[VAL_406]] : tensor<1x8x8x16xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_depthwise_conv2d( -// CHECK-SAME: [[VAL_407:%.*]]: tensor<1x8x8x207xf32>, [[VAL_408:%.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { -// CHECK: [[VAL_409:%.*]] = "tf.DepthwiseConv2dNative"([[VAL_407]], [[VAL_408]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> -// CHECK: return [[VAL_409]] : tensor<1x8x8x16xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_conv2d_valid_padding( -// CHECK-SAME: [[VAL_410:%.*]]: tensor<1x8x8x207xf32>, [[VAL_411:%.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { -// CHECK: [[VAL_412:%.*]] = "tf.Conv2D"([[VAL_410]], [[VAL_411]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> -// CHECK: return [[VAL_412]] : tensor<1x8x8x16xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_reduce_to_sum( -// CHECK-SAME: [[VAL_413:%.*]]: tensor<1x256xf32>) -> tensor<1xf32> { -// CHECK: [[VAL_414:%.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> -// CHECK: [[VAL_415:%.*]] = "tf.Sum"([[VAL_413:%.*]], [[VAL_414:%.*]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32> -// CHECK: return [[VAL_415]] : tensor<1xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_reduce_to_max( -// CHECK-SAME: [[VAL_416:%.*]]: tensor<1x256xf32>) -> tensor<1xf32> { -// CHECK: [[VAL_417:%.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> -// CHECK: [[VAL_418:%.*]] = "tf.Max"([[VAL_416:%.*]], [[VAL_417:%.*]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32> -// CHECK: return [[VAL_418]] : tensor<1xf32> -// CHECK: } - -// CHECK-LABEL: func @convert_reduce_to_min( -// CHECK-SAME: [[VAL_419:%.*]]: tensor<1x256xf32>) -> tensor<1xf32> { -// CHECK: [[VAL_420:%.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> -// CHECK: [[VAL_421:%.*]] = "tf.Min"([[VAL_419:%.*]], [[VAL_420:%.*]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32> -// CHECK: return [[VAL_421]] : tensor<1xf32> -// CHECK: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index e7e07845fcc..155f84ecc37 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -215,6 +215,60 @@ func @rsqrt_grad_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor< return %0 : tensor<*xf32> } +// %input has 1 batch dimension then 2 block dimensions 
then 1 remainder
+// dimension.
+// CHECK-LABEL: fourdim_SpaceToBatchND
+func @fourdim_SpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x2xi64>) -> tensor<*xf32> {
+  // CHECK-DAG: [[PAD00:%.+]] = "tf.Const"() {value = dense<0> : tensor<1x2xi64>}
+  // CHECK-DAG: [[ZERO_I32:%.+]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK-DAG: [[ZERO_I64:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
+  // CHECK-DAG: [[ONE_I64:%.+]] = "tf.Const"() {value = dense<1> : tensor<i64>}
+  // CHECK-DAG: [[FULL_PADDINGS:%.+]] = "tf.ConcatV2"([[PAD00]], %arg2, [[PAD00]], [[ZERO_I64]])
+  // CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
+  // CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]])
+  // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Sum"([[FULL_PADDINGS]], [[ONE_I64]])
+  // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>}
+  // CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.Add"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
+  // CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]])
+  // CHECK-DAG: [[BLOCK_SHAPE_SPLITS:%.+]]:2 = "tf.Split"([[ZERO_I32]], %arg1)
+  // CHECK-DAG: [[OUTER_SHAPE_0:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#1, [[BLOCK_SHAPE_SPLITS]]#0)
+  // CHECK-DAG: [[OUTER_SHAPE_1:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#2, [[BLOCK_SHAPE_SPLITS]]#1)
+  // CHECK-DAG: [[RESHAPED_SHAPE:%.+]] = "tf.ConcatV2"([[PADDED_SHAPE_SPLITS]]#0, [[OUTER_SHAPE_0]], [[BLOCK_SHAPE_SPLITS]]#0, [[OUTER_SHAPE_1]], [[BLOCK_SHAPE_SPLITS]]#1, [[PADDED_SHAPE_SPLITS]]#3, [[ZERO_I64]])
+  // CHECK-DAG: [[PERMUTATION:%.+]] = "tf.Const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi64>}
+  // CHECK-DAG: [[OUTPUT_BATCH_PART:%.+]] = "tf.Mul"([[PADDED_SHAPE_SPLITS]]#0, [[BLOCK_SHAPE_SPLITS]]#0)
+  // CHECK-DAG: [[OUTPUT_BATCH:%.+]] = "tf.Mul"([[OUTPUT_BATCH_PART]], [[BLOCK_SHAPE_SPLITS]]#1)
+  // CHECK-DAG: [[OUTPUT_SHAPE:%.+]] = "tf.ConcatV2"([[OUTPUT_BATCH]], [[OUTER_SHAPE_0]], [[OUTER_SHAPE_1]], [[PADDED_SHAPE_SPLITS]]#3, [[ZERO_I64]])
+  // CHECK-DAG: [[RESHAPED:%.+]] = "tf.Reshape"([[PADDED]], [[RESHAPED_SHAPE]])
+  // CHECK-DAG: [[PERMUTED:%.+]] = "tf.Transpose"([[RESHAPED]], [[PERMUTATION]])
+  // CHECK-DAG: [[RESULT:%.+]] = "tf.Reshape"([[PERMUTED]], [[OUTPUT_SHAPE]])
+  // CHECK-DAG: return [[RESULT]]
+  %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// %input has 1 batch dimension then 3 block dimensions then 2 remainder
+// dimensions. This checks only ops that are specific to the case with 3 block
+// dimension and 2 remainder dimensions.
+// CHECK-LABEL: sixdim_SpaceToBatchND +func @sixdim_SpaceToBatchND(%input: tensor<3x5x7x9x10x11xf32>, %block_shape: tensor<3xi64>, %paddings: tensor<3x2xi64>) -> tensor<*xf32> { + // CHECK-DAG: [[PAD00:%.+]] = "tf.Const"() + // CHECK-DAG: [[FULL_PADDINGS:%.+]] = "tf.ConcatV2"([[PAD00]], %arg2, [[PAD00]], [[PAD00]], {{.+}}) + // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 9, 10, 11]> : tensor<6xi64>} + // CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:6 = "tf.Split" + // CHECK-DAG: [[BLOCK_SHAPE_SPLITS:%.+]]:3 = "tf.Split" + // CHECK-DAG: [[OUTER_SHAPE_0:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#1, [[BLOCK_SHAPE_SPLITS]]#0) + // CHECK-DAG: [[OUTER_SHAPE_1:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#2, [[BLOCK_SHAPE_SPLITS]]#1) + // CHECK-DAG: [[OUTER_SHAPE_2:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#3, [[BLOCK_SHAPE_SPLITS]]#2) + // CHECK-DAG: [[RESHAPED_SHAPE:%.+]] = "tf.ConcatV2"([[PADDED_SHAPE_SPLITS]]#0, [[OUTER_SHAPE_0]], [[BLOCK_SHAPE_SPLITS]]#0, [[OUTER_SHAPE_1]], [[BLOCK_SHAPE_SPLITS]]#1, [[OUTER_SHAPE_2]], [[BLOCK_SHAPE_SPLITS]]#2, [[PADDED_SHAPE_SPLITS]]#4, [[PADDED_SHAPE_SPLITS]]#5, {{.+}}) + // CHECK-DAG: [[PERMUTATION:%.+]] = "tf.Const"() {value = dense<[2, 4, 6, 0, 1, 3, 5, 7, 8]> : tensor<9xi64>} + // CHECK-DAG: [[OUTPUT_BATCH_PART1:%.+]] = "tf.Mul"([[PADDED_SHAPE_SPLITS]]#0, [[BLOCK_SHAPE_SPLITS]]#0) + // CHECK-DAG: [[OUTPUT_BATCH_PART2:%.+]] = "tf.Mul"([[OUTPUT_BATCH_PART1]], [[BLOCK_SHAPE_SPLITS]]#1) + // CHECK-DAG: [[OUTPUT_BATCH:%.+]] = "tf.Mul"([[OUTPUT_BATCH_PART2]], [[BLOCK_SHAPE_SPLITS]]#2) + // CHECK-DAG: [[OUTPUT_SHAPE:%.+]] = "tf.ConcatV2"([[OUTPUT_BATCH]], [[OUTER_SHAPE_0]], [[OUTER_SHAPE_1]], [[OUTER_SHAPE_2]], [[PADDED_SHAPE_SPLITS]]#4, [[PADDED_SHAPE_SPLITS]]#5, {{.+}}) + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x9x10x11xf32>, tensor<3xi64>, tensor<3x2xi64>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + // CHECK-LABEL: SoftmaxCrossEntropyWithLogits // CHECK-SAME: %[[FEATURES:.*]]: tensor<2x3xf32>, %[[LABELS:.*]]: tensor<2x3xf32> func @SoftmaxCrossEntropyWithLogits(%features: tensor<2x3xf32>, %labels: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>) { @@ -353,8 +407,16 @@ func @ZerosLike_variant(%arg0: tensor>>) -> tensor>> } -// CHECK-LABEL: func @addN -func @addN(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> { +// CHECK-LABEL: func @addN_2 +func @addN_2(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1) + // return %[[SUM0]] + %0 = "tf.AddN"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// CHECK-LABEL: func @addN_3 +func @addN_3(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1) // CHECK: %[[SUM1:.*]] = "tf.AddV2"(%[[SUM0]], %arg2) // return %[[SUM1]] @@ -362,6 +424,27 @@ func @addN(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @addN_4 +func @addN_4(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1) + // CHECK: %[[SUM1:.*]] = "tf.AddV2"(%arg2, %arg3) + // CHECK: %[[SUM2:.*]] = "tf.AddV2"(%[[SUM0]], %[[SUM1]]) + // return %[[SUM2]] + %0 = "tf.AddN"(%arg0, %arg1, %arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// 
CHECK-LABEL: func @addN_5 +func @addN_5(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>, %arg4: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1) + // CHECK: %[[SUM1:.*]] = "tf.AddV2"(%arg2, %arg3) + // CHECK: %[[SUM2:.*]] = "tf.AddV2"(%[[SUM0]], %[[SUM1]]) + // CHECK: %[[SUM3:.*]] = "tf.AddV2"(%[[SUM2]], %arg4) + // return %[[SUM3]] + %0 = "tf.AddN"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + // CHECK-LABEL: func @addN_variant func @addN_variant(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>) -> tensor>> { // CHECK: tf.AddN @@ -450,13 +533,39 @@ func @DynamicStitch_duplicates(%arg0: tensor<2x2xf32>) -> tensor<1x2xf32> { return %0 : tensor<1x2xf32> } -func @Reciprocal(%arg0: tensor<*xf32>) -> tensor<*xf32> { +// CHECK-LABEL: @Reciprocal_i32 +func @Reciprocal_i32(%arg0: tensor<*xi32>) -> tensor<*xi32> { + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor, tensor<*xi32>) -> tensor<*xi32> + %0 = "tf.Reciprocal"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> + return %0 : tensor<*xi32> +} + +// CHECK-LABEL: @Reciprocal_f32 +func @Reciprocal_f32(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor, tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Reciprocal"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: @Reciprocal_complexf32 +func @Reciprocal_complexf32(%arg0: tensor<*xcomplex>) -> tensor<*xcomplex> { + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<(1.000000e+00,0.000000e+00)> : tensor>} : () -> tensor> + // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor>, tensor<*xcomplex>) -> tensor<*xcomplex> + %0 = "tf.Reciprocal"(%arg0) : (tensor<*xcomplex>) -> tensor<*xcomplex> + return %0 : tensor<*xcomplex> +} + +// CHECK-LABEL: @Reciprocal_complexf64 +func @Reciprocal_complexf64(%arg0: tensor<*xcomplex>) -> tensor<*xcomplex> { + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<(1.000000e+00,0.000000e+00)> : tensor>} : () -> tensor> + // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor>, tensor<*xcomplex>) -> tensor<*xcomplex> + %0 = "tf.Reciprocal"(%arg0) : (tensor<*xcomplex>) -> tensor<*xcomplex> + return %0 : tensor<*xcomplex> +} + +// CHECK-LABEL: @ScatterNd func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> { // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32> // CHECK: "tf.TensorScatterUpdate"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32> @@ -465,3 +574,16 @@ func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> { %0 = "tf.ScatterNd"(%arg0, %arg1, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32> return %0 : tensor<8xf32> } + +// CHECK-LABEL: @_UnaryOpsComposition +// CHECK-SAME: %[[ARG0:.*]]: tensor<4xf32> +func @_UnaryOpsComposition(%arg0: tensor<4xf32>) -> tensor<4xf32> { + + // CHECK: %[[RESULT0:.*]] = "tf.Asin"(%[[ARG0]]) + // CHECK: %[[RESULT1:.*]] = "tf.Abs"(%[[RESULT0]]) + // CHECK: %[[RESULT2:.*]] = "tf.Log"(%[[RESULT1]]) + // CHECK: return %[[RESULT2]] + + %0 = "tf._UnaryOpsComposition"(%arg0) {op_names = ["Asin", "Abs", "Log"]} : (tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir index 3efa0b09439..dc99d9d6343 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir @@ -1,53 +1,286 @@ // RUN: tf-opt %s -tf-mark-ops-for-outside-compilation | FILECHECK_OPTS="" FileCheck %s - -// CHECK-LABEL: func @op_string_result -func @op_string_result() -> tensor { +// CHECK-LABEL: func @unsupported_op_no_soft_placement +func @unsupported_op_no_soft_placement() -> tensor { %0 = "tf_device.cluster"() ( { - // CHECK: "tf.A" + // CHECK: "tf.UnsupportedOp" // CHECK-NOT: _xla_outside_compilation - // CHECK: "tf.B" - // CHECK-SAME: _xla_outside_compilation - // CHECK: "tf.C" + // CHECK: "tf.Identity" // CHECK-NOT: _xla_outside_compilation - %1 = "tf.A"() : () -> tensor - %2 = "tf.B"(%1) : (tensor) -> tensor - %3 = "tf.C"(%1) : (tensor) -> tensor - tf_device.return %3 : tensor - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor - return %0 : tensor + %1 = "tf.UnsupportedOp"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Identity"(%1) : (tensor) -> tensor + tf_device.return %2 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor } -// CHECK-LABEL: func @op_string_operand -func @op_string_operand(%arg0: tensor) -> tensor { +// CHECK-LABEL: func @unsupported_op_soft_placement_false +func @unsupported_op_soft_placement_false() -> tensor { %0 = "tf_device.cluster"() ( { - // CHECK: "tf.A" + // CHECK: "tf.UnsupportedOp" // CHECK-NOT: _xla_outside_compilation - // CHECK: "tf.B" + // CHECK: "tf.Identity" + // CHECK-NOT: _xla_outside_compilation + %1 = "tf.UnsupportedOp"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Identity"(%1) : (tensor) -> tensor + tf_device.return %2 : tensor + }) {allow_soft_placement = false, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @unsupported_op +func @unsupported_op() -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.UnsupportedOp" // CHECK-SAME: _xla_outside_compilation - // CHECK: "tf.C" + // CHECK: "tf.Identity" // CHECK-NOT: _xla_outside_compilation - %1 = "tf.A"() : () -> tensor - %2 = "tf.B"(%arg0) : (tensor) -> tensor - %3 = "tf.C"(%2) : (tensor) -> tensor - tf_device.return %3 : tensor - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor - return %0 : tensor + %1 = "tf.UnsupportedOp"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Identity"(%1) : (tensor) -> tensor + tf_device.return %2 : tensor + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @tf2xla_fallback_op +func @tf2xla_fallback_op() -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.UnsupportedOp" + // CHECK-SAME: _xla_outside_compilation + // CHECK: "tf.Identity" + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.Sinh" + // CHECK-NOT: _xla_outside_compilation + %1 = "tf.UnsupportedOp"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + %3 = "tf.Identity"(%1) : (tensor) -> tensor + %4 = "tf.Sinh"(%2) : (tensor) -> tensor + tf_device.return %4 : tensor + }) 
{allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ignore_embedding_ops +func @ignore_embedding_ops() -> () { + "tf_device.cluster"() ( { + // CHECK: "tf.RecvTPUEmbeddingActivations" + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.SendTPUEmbeddingGradients" + // CHECK-NOT: _xla_outside_compilation + %2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) + "tf.SendTPUEmbeddingGradients"(%2#0, %2#1) {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D", operand_segment_sizes = dense<[2, 0]> : vector<2xi32>} : (tensor<2x2xf32>, tensor<4x4xf32>) -> () + tf_device.return + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () + return +} + +// CHECK-LABEL: func @op_string_result +func @op_string_result() -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.Const" + // CHECK-SAME: _xla_outside_compilation + // CHECK-SAME: tf.string + // CHECK: "tf.Identity" + // CHECK-NOT: _xla_outside_compilation + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x"> : tensor} : () -> tensor + %3 = "tf.Identity"(%1) : (tensor) -> tensor + tf_device.return %3 : tensor + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @op_string_operand +func @op_string_operand(%arg0: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.StringToNumber" + // CHECK-SAME: _xla_outside_compilation + // CHECK-SAME: tf.string + // CHECK: "tf.Identity" + // CHECK-NOT: _xla_outside_compilation + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.StringToNumber"(%arg0) {out_type = f32} : (tensor) -> tensor + %3 = "tf.Identity"(%1) : (tensor) -> tensor + tf_device.return %3 : tensor + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor } // CHECK-LABEL: func @op_string_operand_string_result -func @op_string_operand_string_result(%arg0: tensor) -> tensor { +func @op_string_operand_string_result(%arg0: tensor) -> tensor { %0 = "tf_device.cluster"() ( { - // CHECK: "tf.A" + // CHECK: "tf.Const"() {value = dense<1> : tensor} // CHECK-NOT: _xla_outside_compilation - // CHECK: "tf.B" + // CHECK: "tf.Identity" // CHECK-SAME: _xla_outside_compilation - // CHECK: "tf.C" + // CHECK-SAME: tf.string + // CHECK: "tf.Identity" // CHECK-NOT: _xla_outside_compilation - %1 = "tf.A"() : () -> tensor - %2 = "tf.B"(%arg0) : (tensor) -> tensor - %3 = "tf.C"(%1) : (tensor) -> tensor - tf_device.return %3 : tensor - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor - return %0 : tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Identity"(%arg0) : (tensor) -> tensor + %3 = "tf.Identity"(%1) : (tensor) -> tensor + tf_device.return %3 : tensor + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// Test that a tf.IfRegion op with a captured string operand is marked for outside 
compilation. + +// CHECK-LABEL: func @if_region_captured_string +func @if_region_captured_string(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.IfRegion" + // CHECK: "tf.StringToNumber" + // CHECK-NOT: _xla_outside_compilation + // CHECK: _xla_outside_compilation = "auto1", is_stateless = true + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.IfRegion"(%arg0) ( { + %3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor) -> tensor + "tf.Yield"(%3) : (tensor) -> () + }, { + %4 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + "tf.Yield"(%4) : (tensor) -> () + }) {is_stateless = true} : (tensor) -> (tensor) + %5 = "tf.Identity"(%2) : (tensor) -> tensor + tf_device.return %5 : tensor + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// Test that ops with string results/operands inside a tf.IfRegion branch are marked for outside compilation. + +// CHECK-LABEL: func @if_region_string_op +func @if_region_string_op(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.IfRegion" + // CHECK-NOT: _xla_outside_compilation + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.IfRegion"(%arg0) ( { + %3 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + "tf.Yield"(%3) : (tensor) -> () + }, { + // CHECK: "tf.Const"() {_xla_outside_compilation = "auto0", value = dense<"1.0"> : tensor} + // CHECK-NEXT: "tf.StringToNumber" + // CHECK-SAME: _xla_outside_compilation + %4 = "tf.Const"() {value = dense<"1.0"> : tensor} : () -> tensor + %5 = "tf.StringToNumber"(%4) {out_type = f32} : (tensor) -> tensor + "tf.Yield"(%5) : (tensor) -> () + // CHECK: {is_stateless + }) {is_stateless = true} : (tensor) -> (tensor) + %6 = "tf.Identity"(%2) : (tensor) -> tensor + tf_device.return %6: tensor + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// Test that ops with string results/operands inside a nested tf.IfRegion branch are marked for outside compilation. 
+ +// CHECK-LABEL: func @nested_if_region_string_op +func @nested_if_region_string_op(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.IfRegion" + // CHECK-NOT: _xla_outside_compilation + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.IfRegion"(%arg0) ( { + %3 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + "tf.Yield"(%3) : (tensor) -> () + }, { + // CHECK: "tf.Const"() {value = dense : tensor} + // CHECK-NOT: _xla_outside_compilation + %4 = "tf.Const"() {value = dense : tensor} : () -> tensor + %5 = "tf.IfRegion"(%4)({ + // CHECK: "tf.Const"() {_xla_outside_compilation = "auto0", value = dense<"1.0"> : tensor} + // CHECK-NEXT: "tf.StringToNumber" + // CHECK-SAME: _xla_outside_compilation + %6 = "tf.Const"() {value = dense<"1.0"> : tensor} : () -> tensor + %7 = "tf.StringToNumber"(%6) {out_type = f32} : (tensor) -> tensor + "tf.Yield"(%7) : (tensor) -> () + }, { + // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor} + // CHECK-NOT: _xla_outside_compilation + %8 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + "tf.Yield"(%8) : (tensor) -> () + // CHECK: {is_stateless + }){is_stateless = true} : (tensor) -> (tensor) + "tf.Yield"(%5) : (tensor) -> () + // CHECK: {is_stateless + }) {is_stateless = true} : (tensor) -> (tensor) + %9 = "tf.Identity"(%2) : (tensor) -> tensor + tf_device.return %9: tensor + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// Test that a tf.WhileRegion op with a captured string operand is marked for outside compilation. + +// CHECK-LABEL: func @while_region_captured_string +func @while_region_captured_string(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.WhileRegion" + // CHECK: "tf.StringToNumber" + // CHECK: _xla_outside_compilation = "auto1", is_stateless = true + %1 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + %2:2 = "tf.WhileRegion"(%1, %arg0) ( { + ^bb0(%carg0: tensor, %carg1: tensor): + %limit = constant dense<5> : tensor + %cond = "tf.NotEqual"(%carg1, %limit) : (tensor, tensor) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, { + ^bb0(%barg0: tensor, %barg1: tensor): + %one = constant dense<1> : tensor + %sub = "tf.Sub"(%barg1, %one) : (tensor, tensor) -> tensor + %3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor) -> tensor + "tf.Yield"(%3, %sub) : (tensor, tensor) -> () + }) {is_stateless = true} : (tensor, tensor) -> (tensor, tensor) + // CHECK: "tf.Identity" + // CHECK-NOT: _xla_outside_compilation + %5 = "tf.Identity"(%2#0) : (tensor) -> (tensor) + tf_device.return %5 : tensor + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// Test that an unsupported op within a tf.WhileRegion is marked for outside compilation. 
+ +// CHECK-LABEL: func @while_region_unsupported_op +func @while_region_unsupported_op(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.WhileRegion" + %1 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + %2:2 = "tf.WhileRegion"(%1, %arg0) ( { + ^bb0(%carg0: tensor, %carg1: tensor): + %limit = constant dense<5> : tensor + %cond = "tf.NotEqual"(%carg1, %limit) : (tensor, tensor) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, { + ^bb0(%barg0: tensor, %barg1: tensor): + %one = constant dense<1> : tensor + %sub = "tf.Sub"(%barg1, %one) : (tensor, tensor) -> tensor + // CHECK: "tf.UnsupportedOp" + // CHECK-SAME: _xla_outside_compilation + %3 = "tf.UnsupportedOp"() {value = dense<1> : tensor} : () -> tensor + // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor} + %4 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + "tf.Yield"(%4, %sub) : (tensor, tensor) -> () + // CHECK: {is_stateless = true + }) {is_stateless = true} : (tensor, tensor) -> (tensor, tensor) + // CHECK: "tf.Identity" + // CHECK-NOT: _xla_outside_compilation + %5 = "tf.Identity"(%2#0) : (tensor) -> (tensor) + tf_device.return %5 : tensor + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/case.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/case.mlir new file mode 100644 index 00000000000..2f2ee6f1286 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/case.mlir @@ -0,0 +1,38 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 486 : i32}} { + func @main() { + tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + %outputs_2, %control_3 = tf_executor.island wraps "tf.Case"(%outputs_0, %outputs) {Tin = [i32], Tout = [i32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], branches = [@indexed_case_branch0_40, @indexed_case_branch1_50], device = "", is_stateless = true, output_shapes = [#tf.shape<>]} : (tensor, tensor) -> tensor<*xi32> loc("stateless_case") + %outputs_4, %control_5 = tf_executor.island wraps "tf.Identity"(%outputs_2) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + %outputs_6, %control_7 = tf_executor.island wraps "tf.Case"(%outputs_0, %outputs) {Tin = [i32], Tout = [i32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], branches = [@indexed_case_branch0_40, @indexed_case_branch1_50], device = "", is_stateless = false, output_shapes = [#tf.shape<>]} : (tensor, tensor) -> tensor<*xi32> loc("regular_case") + tf_executor.fetch + } + return + } + + func @indexed_case_branch0_40(%arg0: tensor) -> tensor<*xi32> attributes {sym_visibility = "private"} { + %0 = tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.AddV2"(%arg0, %outputs) {device = ""} : (tensor, tensor) -> tensor<*xi32> + tf_executor.fetch %outputs_0 : tensor<*xi32> + } + return %0 : tensor<*xi32> + } 
+ + func @indexed_case_branch1_50(%arg0: tensor) -> tensor<*xi32> attributes {sym_visibility = "private"} { + %0 = tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<2> : tensor} : () -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.AddV2"(%arg0, %outputs) {device = ""} : (tensor, tensor) -> tensor<*xi32> + tf_executor.fetch %outputs_0 : tensor<*xi32> + } + return %0 : tensor<*xi32> + } +} + +// CHECK: name: "stateless_case" +// CHECK-NEXT: "StatelessCase" +// CHECK: name: "regular_case" +// CHECK-NEXT: "Case" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_list_attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_list_attr.mlir index c6543f3121e..09a38b5b5de 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_list_attr.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_list_attr.mlir @@ -43,7 +43,7 @@ func @main() { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK: } - %1:2 = tf_executor.island wraps "tf.Case"(%0#0) {Tin = [], Tout = ["tfdtype$DT_FLOAT"], branches = [@foo, @bar], device = "", output_shapes = []} : (tensor) -> tensor<*xf32> loc("Case") + %1:2 = tf_executor.island wraps "tf.Case"(%0#0) {Tin = [], Tout = ["tfdtype$DT_FLOAT"], branches = [@foo, @bar], device = "", output_shapes = [], is_stateless = false} : (tensor) -> tensor<*xf32> loc("Case") tf_executor.fetch } return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir index e9d4e441a10..3e8935b699e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir @@ -212,6 +212,28 @@ func @testNoOutputs(%arg0: tensor, %arg1: tensor<*xf32>) -> () { return } +// ----- +// Check ToBool folding for IfRegion +// CHECK: func @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: "tf.Neg" +// CHECK: func @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: "tf.Abs" +// CHECK-LABEL: @testToBoolFold +func @testToBoolFold(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { + // CHECK-NEXT: "tf.If"(%arg0, %arg1) + // CHECK-SAME: else_branch = @tf.IfRegion_else + // CHECK-SAME: then_branch = @tf.IfRegion_then + %tobool = "tf.ToBool"(%arg0) : (tensor) -> tensor + %0 = "tf.IfRegion"(%tobool) ({ + %1 = "tf.Abs"(%arg1) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%1) : (tensor<*xf32>) -> () + }, { + %2 = "tf.Neg"(%arg1) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%2) : (tensor<*xf32>) -> () + }) {is_stateless = true} : (tensor) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + // ----- // Simple WhileRegion @@ -592,3 +614,64 @@ func @testWhileRegionBlockArgMismatch(%arg0 : tensor<*xf32>, %arg1 : tensor // CHECK: return [[Result]]#0 return %0#0 : tensor<*xf32> } + +// ----- + +// Simple trivially transformable while with ToBool +// CHECK: func @while_cond +// CHECK: func @while_body +// CHECK-LABEL: testWhileRegionTrivial +func @while_cond(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor +func @while_body(%arg0 : tensor<*xf32>, %arg1 : tensor) -> (tensor<*xf32>, tensor) +func @testWhileRegionTrivial(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor<*xf32> { + // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {body = @while_body, cond = @while_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + 
^bb0(%carg0: tensor<*xf32>, %carg1: tensor): + %cond_i32 = call @while_cond(%carg0, %carg1) : (tensor<*xf32>, tensor) -> tensor + %cond = "tf.ToBool"(%cond_i32) : (tensor) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xf32>, %barg1: tensor): + %bdy:2 = call @while_body(%barg0, %barg1) : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + "tf.Yield"(%bdy#0, %bdy#1) : (tensor<*xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xf32> +} + +// ----- + +// Test tf.IfRegion device is preserved. +// CHECK-LABEL: func @testIfRegionDevice +func @testIfRegionDevice(%arg0: tensor) { + "tf.IfRegion"(%arg0) ({ + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) {is_stateless = false, device = "/device:CPU:0"} : (tensor) -> () + + // CHECK: "tf.If" + // CHECK-SAME: device = "/device:CPU:0" + return +} + +// ----- + +// Test tf.WhileRegion device is preserved. +// CHECK-LABEL: func @testWhileRegionDevice +func @testWhileRegionDevice() { + "tf.WhileRegion"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.Yield"(%0) : (tensor) -> () + }, { + "tf.Yield"() : () -> () + }) {is_stateless = false, device = "/device:CPU:0"} : () -> () + + // CHECK: "tf.While" + // CHECK-SAME: device = "/device:CPU:0" + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir index 9931a45f995..487234ce958 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tf-replicate-to-island | FileCheck %s +// RUN: tf-opt -split-input-file %s -tf-replicate-to-island | FileCheck %s // Tests per replica island has same control operands as island holding // replicate. @@ -223,3 +223,219 @@ func @replica_id_attr_added(%arg0: tensor, %arg1: tensor // CHECK: "tf.A" // CHECK-NOT: _xla_replica_id // CHECK: tf_executor.fetch + + +// Tests device ordinals are added to `tf._XlaSendFromHost`/`tf._XlaRecvAtHost` +// based on the first TPU core device id. +// CHECK-LABEL: func @device_ordinals +func @device_ordinals(%arg0: tensor, %arg1: tensor<2x!tf.string>) { + tf_executor.graph { + tf_executor.island { + tf_device.replicate([%arg0, %arg0] as %arg2: tensor) {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { + %0 = "tf._XlaRecvAtHost"(%arg1) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor + "tf._XlaSendFromHost"(%0, %arg1) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_recv_0"} : (tensor, tensor<2x!tf.string>) -> () + "tf.NoOp"() : () -> () + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: tf_executor.island +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 1 +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 1 +// CHECK: "tf.NoOp" +// CHECK: tf_executor.island +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 2 +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 2 +// CHECK: "tf.NoOp" + +// ----- + +// Tests functions with replica variant ops reachable from a replicate region +// is cloned and remapped. 
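Put differently, a function reachable from the replicate body that contains replica-variant ops (here `tf._XlaRecvAtHost` / `tf._XlaSendFromHost` with a `device_ordinal`) cannot stay shared across replicas, so the pass is expected to emit one clone per replica and rewrite the ordinal from the `TPU_REPLICATED_CORE_0` device list. A rough sketch of the intended result for the two-replica case below, with the clone name and element types purely illustrative (the CHECK captures only pin the pattern):

// Replica 0 is assigned /device:TPU:1, so its clone uses device_ordinal = 1.
func @send_recv_replica_0(%arg0: tensor<2x!tf.string>) {
  %0 = "tf._XlaRecvAtHost"(%arg0) {_xla_has_host_transfer = true, device_ordinal = 1 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor<f32>
  "tf._XlaSendFromHost"(%0, %arg0) {_xla_has_host_transfer = true, device_ordinal = 1 : i64, key = "host_compute_channel_recv_0"} : (tensor<f32>, tensor<2x!tf.string>) -> ()
  "tf.NoOp"() : () -> ()
  return
}
// The second clone is identical except for device_ordinal = 2, since replica 1 runs on /device:TPU:2.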
+ +// CHECK-LABEL: func @call_with_replicate_variant_ops +func @call_with_replicate_variant_ops(%arg0: tensor, %arg1: tensor<2x!tf.string>) { + tf_executor.graph { + tf_executor.island { + tf_device.replicate([%arg0, %arg0] as %arg2: tensor) {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { + "tf.StatefulPartitionedCall"(%arg1) {config = "", config_proto = "", executor_type = "", f = @send_recv} : (tensor<2x!tf.string>) -> () + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = [[CALL_REPLICA_0:@[a-z0-9_]+]] +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = [[CALL_REPLICA_1:@[a-z0-9_]+]] + +func @send_recv(%arg0: tensor<2x!tf.string>) { + %0 = "tf._XlaRecvAtHost"(%arg0) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor + "tf._XlaSendFromHost"(%0, %arg0) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_recv_0"} : (tensor, tensor<2x!tf.string>) -> () + "tf.NoOp"() : () -> () + return +} + +// CHECK: func [[CALL_REPLICA_0]] +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 1 +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 1 + +// CHECK: func [[CALL_REPLICA_1]] +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 2 +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 2 + +// ----- + +// Tests transitive functions with replica variant ops reachable from a +// replicate region is cloned and remapped. + +// CHECK-LABEL: func @call_with_replicate_variant_ops +func @call_with_replicate_variant_ops(%arg0: tensor, %arg1: tensor<2x!tf.string>) { + tf_executor.graph { + tf_executor.island { + tf_device.replicate([%arg0, %arg0] as %arg2: tensor) {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { + "tf.StatefulPartitionedCall"(%arg1) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor<2x!tf.string>) -> () + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = [[CALLEE_REPLICA_0:@[a-z0-9_]+]] +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = [[CALLEE_REPLICA_1:@[a-z0-9_]+]] + +func @callee(%arg0: tensor<2x!tf.string>) { + "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @send_recv} : (tensor<2x!tf.string>) -> () + return +} + +func @send_recv(%arg0: tensor<2x!tf.string>) { + %0 = "tf._XlaRecvAtHost"(%arg0) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor + "tf._XlaSendFromHost"(%0, %arg0) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_recv_0"} : (tensor, tensor<2x!tf.string>) -> () + "tf.NoOp"() : () -> () + return +} + +// CHECK: func [[CALLEE_REPLICA_0]] +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = [[TRANSITIVE_CALLEE_REPLICA_0:@[a-z0-9_]+]] + +// CHECK: func [[TRANSITIVE_CALLEE_REPLICA_0]] +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 1 +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 1 + +// CHECK: func [[CALLEE_REPLICA_1]] +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = 
[[TRANSITIVE_CALLEE_REPLICA_1:@[a-z0-9_]+]] + +// CHECK: func [[TRANSITIVE_CALLEE_REPLICA_1]] +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 2 +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 2 + +// ----- + +// Tests functional control flow functions with replica variant ops reachable +// from a replicate region is cloned and remapped. Only the branches reachable +// with replica variant ops are cloned. + +// CHECK-LABEL: func @control_flow_with_replicate_variant_ops +func @control_flow_with_replicate_variant_ops(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<2x!tf.string>) { + tf_executor.graph { + tf_executor.island { + tf_device.replicate([%arg0, %arg0] as %arg4: tensor, [%arg1, %arg1] as %arg5: tensor, [%arg2, %arg2] as %arg6: tensor) {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { + %0 = "tf.If"(%arg4, %arg5, %arg6, %arg3) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor, tensor, tensor, tensor<2x!tf.string>) -> tensor + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: "tf.If" +// CHECK-SAME: else_branch = @cond_false +// CHECK-SAME: then_branch = [[COND_TRUE_REPLICA_0:@[a-z0-9_]+]] +// CHECK: "tf.If" +// CHECK-SAME: else_branch = @cond_false +// CHECK-SAME: then_branch = [[COND_TRUE_REPLICA_1:@[a-z0-9_]+]] + +func @cond_false(%arg0: tensor, %arg1: tensor, %arg2: tensor<2x!tf.string>) -> tensor { + return %arg0 : tensor +} + +// CHECK-NOT: func @cond_false.+( + +func @cond_true(%arg0: tensor, %arg1: tensor, %arg2: tensor<2x!tf.string>) -> tensor { + "tf._XlaSendFromHost"(%arg1, %arg2) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_recv_0"} : (tensor, tensor<2x!tf.string>) -> () + %0 = "tf._XlaRecvAtHost"(%arg2) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor + return %0 : tensor +} + +// CHECK: func [[COND_TRUE_REPLICA_0]] +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 1 +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 1 + +// CHECK: func [[COND_TRUE_REPLICA_1]] +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 2 +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 2 + +// ----- + +// Tests function with no replica variant ops reachable from a replicate region +// is not cloned. 
+ +// CHECK-LABEL: func @no_replicate_variant_ops +func @no_replicate_variant_ops(%arg0: tensor, %arg1: tensor<2x!tf.string>) { + tf_executor.graph { + tf_executor.island { + tf_device.replicate([%arg0, %arg0] as %arg2: tensor) {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { + "tf.StatefulPartitionedCall"(%arg1) {config = "", config_proto = "", executor_type = "", f = @send_recv} : (tensor<2x!tf.string>) -> () + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = @send_recv + +func @send_recv(%arg0: tensor<2x!tf.string>) { + "tf.NoOp"() : () -> () + return +} + +// CHECK-NOT: @send_recv.+( diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource-alias-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource-alias-analysis-test.mlir new file mode 100644 index 00000000000..e857831e6be --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource-alias-analysis-test.mlir @@ -0,0 +1,363 @@ +// RUN: tf-opt -split-input-file -tf-test-resource-alias-analysis -verify-diagnostics %s | FileCheck %s + +// Test 2 resources that do not alias. + +!tf_res = type tensor<*x!tf.resource>> +// CHECK-LABEL: func @non_aliasing_reads_writes +// expected-remark@below {{Region #0, Arg #0, ID 1 : 1}} +// expected-remark@below {{Region #0, Arg #1, ID 2 : 2}} +func @non_aliasing_reads_writes( + %arg0: !tf_res, + %arg1: !tf_res, + %arg2: tensor<32xf32>) -> (tensor<32xf32>) { + %graph = tf_executor.graph { + // CHECK: tf_executor.island + %island:2 = tf_executor.island { + %read0 = "tf.ReadVariableOp"(%arg0) : (!tf_res) -> tensor<32xf32> + "tf.AssignVariableOp"(%arg0, %arg2) : (!tf_res, tensor<32xf32>) -> () + %read1 = "tf.ReadVariableOp"(%arg1) : (!tf_res) -> tensor<32xf32> + // expected-remark@below {{Result #0, ID 0 : 0}} + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + %read2 = "tf.ReadVariableOp"(%var_handle) : (!tf_res) -> tensor<32xf32> + "tf.AssignVariableOp"(%arg1, %read0) : (!tf_res, tensor<32xf32>) -> () + "tf.AssignVariableOp"(%arg0, %read2) : (!tf_res, tensor<32xf32>) -> () + %read3 = "tf.ReadVariableOp"(%arg0) : (!tf_res) -> tensor<32xf32> + tf_executor.yield %read3 : tensor<32xf32> + } + tf_executor.fetch %island#0 : tensor<32xf32> + } + return %graph : tensor<32xf32> +} + +// ----- +// Tests aliasing of the two resource handles that refer to the same variable. 
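The remarks verified throughout this new test read as `Result #i, ID n : <set>` or `Region #r, Arg #a, ID n : <set>`: every resource value is assigned a numeric ID, and the trailing set lists the IDs of the values it may alias, itself included; `Unknown` means the value may alias any resource. So in the case below, two handles created with the same `container`/`shared_name` land in each other's sets, for example:

// Both handles name variable "v0", so ID 0 and ID 1 alias: each reports the set {0, 1}.
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
%vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>

Pass-through ops such as `tf.Identity` and `tf.IdentityN` join the same alias set, which is what the `%vh1_id` result in the test below exercises.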
+ +!tf_res = type tensor<*x!tf.resource>> +// CHECK-LABEL: func @aliasing_reads_writes +func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // expected-remark@below {{Result #0, ID 0 : 0, 1, 2}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + // expected-remark@below {{Result #0, ID 1 : 0, 1, 2}} + %vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + // expected-remark@below {{Result #0, ID 2 : 0, 1, 2}} + %vh1_id:2 = "tf.IdentityN"(%vh1, %arg0) : (!tf_res, tensor<32xf32>) -> (!tf_res, tensor<32xf32>) + %read0 = "tf.ReadVariableOp"(%vh0) : (!tf_res) -> tensor<32xf32> + "tf.AssignVariableOp"(%vh1_id#0, %arg0) : (!tf_res, tensor<32xf32>) -> () + %read1 = "tf.ReadVariableOp"(%vh0) : (!tf_res) -> tensor<32xf32> + %read2 = "tf.ReadVariableOp"(%vh1) : (!tf_res) -> tensor<32xf32> + "tf.AssignVariableOp"(%vh0, %read2) : (!tf_res, tensor<32xf32>) -> () + "tf.AssignVariableOp"(%vh1_id#0, %read1) : (!tf_res, tensor<32xf32>) -> () + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// ----- +// Test an unknown op that has a resource result is marked unknown + +!tf_res = type tensor<*x!tf.resource>> +// CHECK-LABEL: func @unknown_resource_op +func @unknown_resource_op(%arg0: tensor<32xf32>) -> () { + // expected-remark@below {{Result #0, ID 0 : Unknown}} + %0 = "tf.UnknownVarHandleOp"() : () -> !tf_res +} + +// ----- +// Test aliasing through IfOp + +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @if_op_aliasing +// expected-remark@below {{Region #0, Arg #0, ID 4 : 1, 4}} +// expected-remark@below {{Region #0, Arg #1, ID 5 : 1, 2, 3, 5}} +func @if_op_aliasing(%arg0: !tf_res, %arg1: !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + %read0 = "tf.ReadVariableOp"(%vh0) : (!tf_res) -> tensor<32xf32> + // expected-remark@below {{Result #0, ID 1 : Unknown}} + // expected-remark@below {{Result #1, ID 2 : 1, 2, 3, 5}} + // expected-remark@below {{Result #2, ID 3 : 0, 1, 2, 3, 5}} + %if:3 = "tf.If"(%read0, %arg1, %vh0) { + then_branch = @if_then, else_branch = @if_else, is_stateless = true + } : (tensor<32xf32>, !tf_res, !tf_res) -> (!tf_res, !tf_res, !tf_res) + return +} + +// expected-remark@below {{Region #0, Arg #0, ID 2 : 0, 1, 2}} +// expected-remark@below {{Region #0, Arg #1, ID 3 : 0, 3}} +func @if_then(%arg0: !tf_res, %arg1: !tf_res) -> (!tf_res, !tf_res, !tf_res) { + // expected-remark@below {{Result #0, ID 0 : Unknown}} + %u0 = "tf._UnknownSideEffectingOp_"() : () -> !tf_res + // expected-remark@below {{Result #0, ID 1 : 0, 1, 2}} + %id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res + return %u0, %id0, %id0 : !tf_res, !tf_res, !tf_res +} + +// expected-remark@below {{Region #0, Arg #0, ID 1 : 0, 1}} +// expected-remark@below {{Region #0, Arg #1, ID 2 : 2}} +func @if_else(%arg0: !tf_res, %arg1: !tf_res) -> (!tf_res, !tf_res, !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0, 1}} + %id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res + return %id0, %id0, %arg1 : !tf_res, !tf_res, !tf_res +} + +// ----- +// Test aliasing through CaseOp + +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @case_op_aliasing +// expected-remark@below {{Region #0, Arg #0, ID 4 : 1, 4}} +// expected-remark@below {{Region #0, Arg #1, ID 5 : 1, 2, 3, 5}} +func @case_op_aliasing(%arg0: !tf_res, 
%arg1: !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + %read0 = "tf.ReadVariableOp"(%vh0) : (!tf_res) -> tensor + // expected-remark@below {{Result #0, ID 1 : Unknown}} + // expected-remark@below {{Result #1, ID 2 : 1, 2, 3, 5}} + // expected-remark@below {{Result #2, ID 3 : 0, 1, 2, 3, 5}} + %if:3 = "tf.Case"(%read0, %arg1, %vh0) { + branches = [@case_branch0, @case_branch1, @case_branch2], + is_stateless = true + } : (tensor, !tf_res, !tf_res) -> (!tf_res, !tf_res, !tf_res) + return +} + +// expected-remark@below {{Region #0, Arg #0, ID 2 : 0, 1, 2}} +// expected-remark@below {{Region #0, Arg #1, ID 3 : 0, 3}} +func @case_branch0(%arg0: !tf_res, %arg1: !tf_res) -> (!tf_res, !tf_res, !tf_res) { + // expected-remark@below {{Result #0, ID 0 : Unknown}} + %u0 = "tf._UnknownSideEffectingOp_"() : () -> !tf_res + // expected-remark@below {{Result #0, ID 1 : 0, 1, 2}} + %id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res + return %u0, %id0, %id0 : !tf_res, !tf_res, !tf_res +} + +// expected-remark@below {{Region #0, Arg #0, ID 1 : 0, 1}} +// expected-remark@below {{Region #0, Arg #1, ID 2 : 2}} +func @case_branch1(%arg0: !tf_res, %arg1: !tf_res) -> (!tf_res, !tf_res, !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0, 1}} + %id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res + return %id0, %id0, %arg1 : !tf_res, !tf_res, !tf_res +} + +// expected-remark@below {{Region #0, Arg #0, ID 0 : 0}} +// expected-remark@below {{Region #0, Arg #1, ID 1 : 1}} +func @case_branch2(%arg0: !tf_res, %arg1: !tf_res) -> (!tf_res, !tf_res, !tf_res) { + return %arg0, %arg0, %arg1 : !tf_res, !tf_res, !tf_res +} + +// ----- +// Test aliasing through WhileOp +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @while_op_aliasing +// expected-remark@below {{Region #0, Arg #0, ID 4 : 1, 4}} +// expected-remark@below {{Region #0, Arg #1, ID 5 : 1, 2, 3, 5}} +// expected-remark@below {{Region #0, Arg #2, ID 6 : 1, 2, 3, 6}} +func @while_op_aliasing(%arg0: !tf_res, %arg1: !tf_res, %arg2: !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + // expected-remark@below {{Result #0, ID 1 : Unknown}} + // expected-remark@below {{Result #1, ID 2 : 1, 2, 3, 5, 6}} + // expected-remark@below {{Result #2, ID 3 : 1, 2, 3, 5, 6}} + %w:3 = "tf.While"(%arg0, %arg1, %arg2) { + body = @while_body, cond = @while_cond, is_stateless = false + } : (!tf_res, !tf_res, !tf_res) -> (!tf_res, !tf_res, !tf_res) + return +} + +// CHECK-LABEL: func @while_body +// Return 0 : new unknown resource +// Return 1 : arg2 +// Return 2 : arg1 +// expected-remark@below {{Region #0, Arg #0, ID 1 : 0, 1}} +// expected-remark@below {{Region #0, Arg #1, ID 2 : 0, 2}} +// expected-remark@below {{Region #0, Arg #2, ID 3 : 0, 3}} +func @while_body(%arg0: !tf_res, %arg1: !tf_res, %arg2: !tf_res) -> (!tf_res, !tf_res, !tf_res) { + // expected-remark@below {{Result #0, ID 0 : Unknown}} + %u0 = "tf._UnknownSideEffectingOp_"() : () -> !tf_res + return %u0, %arg2, %arg1 : !tf_res, !tf_res, !tf_res +} + +// CHECK-LABEL: func @while_cond +// expected-remark@below {{Region #0, Arg #0, ID 0 : 0}} +// expected-remark@below {{Region #0, Arg #1, ID 1 : 1}} +// expected-remark@below {{Region #0, Arg #2, ID 2 : 2}} +func @while_cond(%arg0: !tf_res, %arg1: !tf_res, %arg2: !tf_res) -> tensor { + %0 = constant dense : tensor + return %0 : tensor +} + +// ----- +// Test alias 
propagation through calls. +!tf_res = type tensor<*x!tf.resource>> +// CHECK-LABEL: func @aliasing_through_calls +func @aliasing_through_calls(%arg0: tensor<32xf32>) -> () { + // expected-remark@below {{Result #0, ID 0 : 0, 1, 2, 3}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + // expected-remark@below {{Result #0, ID 1 : 0, 1, 2, 3}} + %vh1 = "tf.Identity"(%vh0) : (!tf_res) -> (!tf_res) + // expected-remark@below {{Result #0, ID 2 : Unknown}} + // expected-remark@below {{Result #1, ID 3 : 0, 1, 2, 3}} + %c:2 = call @passthru(%vh1) : (!tf_res) -> (!tf_res, !tf_res) + return +} + +// expected-remark@below {{Region #0, Arg #0, ID 1 : 1}} +func @passthru(%arg0: !tf_res) -> (!tf_res, !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0}} + %vx = "tf.VarHandleOp"() {container = "cf", shared_name = "vx"} : () -> !tf_res + return %vx, %arg0 : !tf_res, !tf_res +} + +// ----- +// Test aliasing through IfRegion + +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @if_region_aliasing +// expected-remark@below {{Region #0, Arg #0, ID 7 : 1, 4, 6, 7}} +// expected-remark@below {{Region #0, Arg #1, ID 8 : 1, 2, 4, 5, 6, 8}} +func @if_region_aliasing(%arg0: !tf_res, %arg1: !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0, 1, 3, 4, 5}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + %read0 = "tf.ReadVariableOp"(%vh0) : (!tf_res) -> tensor + // expected-remark@below {{Result #0, ID 4 : Unknown}} + // expected-remark@below {{Result #1, ID 5 : 0, 1, 2, 3, 4, 5, 6, 8}} + // expected-remark@below {{Result #2, ID 6 : 1, 2, 4, 5, 6, 7, 8}} + %if:3 = "tf.IfRegion"(%read0) ({ + // expected-remark@below {{Result #0, ID 1 : Unknown}} + %u0 = "tf._UnknownSideEffectingOp_"() : () -> !tf_res + // expected-remark@below {{Result #0, ID 2 : 1, 2, 4, 5, 6, 8}} + %id0 = "tf.Identity"(%arg1) : (!tf_res) -> !tf_res + "tf.Yield"(%u0, %id0, %id0) : (!tf_res, !tf_res, !tf_res) -> () + }, { + // expected-remark@below {{Result #0, ID 3 : 0, 1, 3, 4, 5}} + %id0 = "tf.Identity"(%vh0) : (!tf_res) -> !tf_res + "tf.Yield"(%id0, %id0, %arg0) : (!tf_res, !tf_res, !tf_res) -> () + }) {is_stateless = true} : (tensor) -> (!tf_res, !tf_res, !tf_res) + return +} + +// ----- +// Test aliasing through CaseRegion + +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @case_region_aliasing +// expected-remark@below {{Region #0, Arg #0, ID 7 : 1, 4, 6, 7}} +// expected-remark@below {{Region #0, Arg #1, ID 8 : 1, 2, 4, 5, 6, 8}} +func @case_region_aliasing(%arg0: !tf_res, %arg1: !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0, 1, 3, 4, 5}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + %read0 = "tf.ReadVariableOp"(%vh0) : (!tf_res) -> tensor + // expected-remark@below {{Result #0, ID 4 : Unknown}} + // expected-remark@below {{Result #1, ID 5 : 0, 1, 2, 3, 4, 5, 6, 8}} + // expected-remark@below {{Result #2, ID 6 : 1, 2, 4, 5, 6, 7, 8}} + %if:3 = "tf.CaseRegion"(%read0) ({ + // expected-remark@below {{Result #0, ID 1 : Unknown}} + %u0 = "tf._UnknownSideEffectingOp_"() : () -> !tf_res + // expected-remark@below {{Result #0, ID 2 : 1, 2, 4, 5, 6, 8}} + %id0 = "tf.Identity"(%arg1) : (!tf_res) -> !tf_res + "tf.Yield"(%u0, %id0, %id0) : (!tf_res, !tf_res, !tf_res) -> () + }, { + // expected-remark@below {{Result #0, ID 3 : 0, 1, 3, 4, 5}} + %id0 = "tf.Identity"(%vh0) : (!tf_res) -> !tf_res + "tf.Yield"(%id0, %id0, %arg0) : (!tf_res, !tf_res, !tf_res) -> () + }, { + "tf.Yield"(%vh0, %arg1, 
%arg1) : (!tf_res, !tf_res, !tf_res) -> () + }) {is_stateless = true} : (tensor) -> (!tf_res, !tf_res, !tf_res) + return +} + +// ----- +// Test aliasing through WhileRegion +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @while_region_aliasing +// expected-remark@below {{Region #0, Arg #0, ID 11 : 1, 8, 11}} +// expected-remark@below {{Region #0, Arg #1, ID 12 : 1, 8, 9, 10, 12}} +// expected-remark@below {{Region #0, Arg #2, ID 13 : 1, 8, 9, 10, 13}} +func @while_region_aliasing(%arg0: !tf_res, %arg1: !tf_res, %arg2: !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0, 1, 8}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + // expected-remark@below {{Result #0, ID 8 : Unknown}} + // expected-remark@below {{Result #1, ID 9 : 1, 8, 9, 10, 12, 13}} + // expected-remark@below {{Result #2, ID 10 : 1, 8, 9, 10, 12, 13}} + // expected-remark@below {{Region #0, Arg #0, ID 2 : 1, 2, 8}} + // expected-remark@below {{Region #0, Arg #1, ID 3 : 1, 3, 8}} + // expected-remark@below {{Region #0, Arg #2, ID 4 : 1, 4, 8}} + // expected-remark@below {{Region #1, Arg #0, ID 5 : 1, 5, 8}} + // expected-remark@below {{Region #1, Arg #1, ID 6 : 1, 6, 8}} + // expected-remark@below {{Region #1, Arg #2, ID 7 : 1, 7, 8}} + %w:3 = "tf.WhileRegion"(%arg0, %arg1, %arg2) ({ + ^bb0(%carg0: !tf_res, %carg1: !tf_res, %carg2: !tf_res): + %0 = constant dense : tensor + "tf.Yield"(%0) : (tensor) -> () + },{ + ^bb0(%barg0: !tf_res, %barg1: !tf_res, %barg2: !tf_res): + // expected-remark@below {{Result #0, ID 1 : Unknown}} + %u0 = "tf._UnknownSideEffectingOp_"() : () -> !tf_res + "tf.Yield"(%u0, %barg2, %barg1) : (!tf_res, !tf_res, !tf_res) -> () + }) {is_stateless = false} : (!tf_res, !tf_res, !tf_res) -> (!tf_res, !tf_res, !tf_res) + return +} + +// ----- +// Test aliasing through calls +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @aliasing_through_calls +func @aliasing_through_calls(%arg0: tensor<32xf32>) -> () { + // expected-remark@below {{Result #0, ID 0 : 0, 1, 2}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + // expected-remark@below {{Result #0, ID 1 : Unknown}} + // expected-remark@below {{Result #1, ID 2 : 0, 1, 2}} + %c:2 = call @passthru(%vh0) : (!tf_res) -> (!tf_res, !tf_res) + return +} + +// expected-remark@below {{Region #0, Arg #0, ID 1 : 1}} +func @passthru(%arg0: !tf_res) -> (!tf_res, !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + return %vh0, %arg0 : !tf_res, !tf_res +} + +// ----- +// Test aliasing through tf_device.launch +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @aliasing_through_launch +func @aliasing_through_launch(%arg0: tensor<32xf32>) { + // expected-remark@below {{Result #0, ID 0 : 0, 1}} + %vh = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> !tf_res + + // expected-remark@below {{Result #0, ID 1 : 0, 1}} + %launch = "tf_device.launch"() ({ + tf_device.return %vh : !tf_res + }) {device = ""} : () -> !tf_res + return +} + +// ----- +// Test aliasing through tf_device.cluster +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @aliasing_through_cluster +func @aliasing_through_cluster(%arg0: tensor<32xf32>) { + // expected-remark@below {{Result #0, ID 0 : 0, 1}} + %vh = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> !tf_res + + // expected-remark@below {{Result #0, ID 1 : 0, 1}} + %cluster = "tf_device.cluster"() ({ + 
tf_device.return %vh : !tf_res + }) : () -> !tf_res + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir index a4a7c1dad2e..75cafde88e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir @@ -1,31 +1,33 @@ // RUN: tf-opt -split-input-file -verify-diagnostics -tf-resource-device-inference %s | FileCheck %s +!tf_res = type tensor<*x!tf.resource>> + // Tests that the pass can correctly propagate device attributes inside the same // function. // CHECK-LABEL: func @propagate_in_function func @propagate_in_function( - %arg0: tensor<*x!tf.resource>> {tf.device = "/TPU:0"}, - %arg1: tensor<*x!tf.resource>> {tf.device = "/TPU:1"}) { + %arg0: !tf_res {tf.device = "/TPU:0"}, + %arg1: !tf_res {tf.device = "/TPU:1"}) { tf_executor.graph { // CHECK: tf_executor.island %island = tf_executor.island { // CHECK-NEXT: "tf.VarHandleOp" %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/CPU:0"} - : () -> tensor<*x!tf.resource>> + : () -> !tf_res // CHECK-NEXT: "tf.Identity" // CHECK-SAME: {device = "/TPU:0"} - %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) - -> tensor<*x!tf.resource>> + %id0 = "tf.Identity"(%arg0) : (!tf_res) + -> !tf_res // CHECK-NEXT: "tf.Identity" // CHECK-SAME: {device = "/TPU:0"} - %id1 = "tf.Identity"(%id0) : (tensor<*x!tf.resource>>) - -> tensor<*x!tf.resource>> + %id1 = "tf.Identity"(%id0) : (!tf_res) + -> !tf_res // CHECK-NEXT: "tf.Identity" // CHECK-SAME: {device = "/CPU:0"} - %id2 = "tf.Identity"(%var_handle) : (tensor<*x!tf.resource>>) - -> tensor<*x!tf.resource>> - %read = "tf.ReadVariableOp"(%id2) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + %id2 = "tf.Identity"(%var_handle) : (!tf_res) + -> !tf_res + %read = "tf.ReadVariableOp"(%id2) : (!tf_res) -> tensor<32xf32> %id3 = "tf.Identity"(%read) : (tensor<32xf32>) -> tensor<32xf32> tf_executor.yield } @@ -35,30 +37,31 @@ func @propagate_in_function( } // ----- +!tf_res = type tensor<*x!tf.resource>> // Tesets that the pass can propagate through tf.If's branches. 
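The propagation model behind these cases: a resource value takes the device of its definition, either a `tf.device` argument attribute or the `device` attribute on `tf.VarHandleOp`, and the pass copies that device onto pass-through users such as `tf.Identity` and onto the corresponding arguments of the branch and body functions the resource is passed to. A minimal sketch using the `!tf_res` alias from this file, assuming `%arg0` carries `{tf.device = "/TPU:0"}`:

// Before: the identity of a /TPU:0 resource has no device attribute.
%id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res
// After the pass (sketch): the inferred device is attached.
%id0 = "tf.Identity"(%arg0) {device = "/TPU:0"} : (!tf_res) -> !tf_res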
// CHECK-LABEL: func @propagate_if_op func @propagate_if_op( - %arg0: tensor<*x!tf.resource>> {tf.device = "/TPU:0"}, + %arg0: !tf_res {tf.device = "/TPU:0"}, %arg1: tensor) { tf_executor.graph { // CHECK: tf_executor.island %island = tf_executor.island { // CHECK-NEXT: "tf.Identity" // CHECK-SAME: {device = "/TPU:0"} - %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) - -> tensor<*x!tf.resource>> + %id0 = "tf.Identity"(%arg0) : (!tf_res) + -> !tf_res // CHECK-NEXT: "tf.VarHandleOp" %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} - : () -> tensor<*x!tf.resource>> + : () -> !tf_res // CHECK-NEXT: "tf.If" "tf.If"(%arg1, %id0, %var_handle) { then_branch = @if_then, else_branch = @if_else, is_stateless = false} - : (tensor, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>) -> () + : (tensor, !tf_res, + !tf_res) -> () tf_executor.yield } tf_executor.fetch %island : !tf_executor.control @@ -68,19 +71,19 @@ func @propagate_if_op( // CHECK-LABEL: func @if_then func @if_then( - %arg0: tensor<*x!tf.resource>>, - %arg1: tensor<*x!tf.resource>>) { + %arg0: !tf_res, + %arg1: !tf_res) { tf_executor.graph { // CHECK: tf_executor.island %island = tf_executor.island { // CHECK-NEXT: "tf.Identity" // CHECK-SAME: {device = "/TPU:0"} - %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) - -> tensor<*x!tf.resource>> + %id0 = "tf.Identity"(%arg0) : (!tf_res) + -> !tf_res // CHECK-NEXT: "tf.Identity" // CHECK-SAME: {device = "/TPU:1"} - %id1 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>) - -> tensor<*x!tf.resource>> + %id1 = "tf.Identity"(%arg1) : (!tf_res) + -> !tf_res tf_executor.yield } tf_executor.fetch %island : !tf_executor.control @@ -90,15 +93,15 @@ func @if_then( // CHECK-LABEL: func @if_else func @if_else( - %arg0: tensor<*x!tf.resource>>, - %arg1: tensor<*x!tf.resource>>) { + %arg0: !tf_res, + %arg1: !tf_res) { tf_executor.graph { // CHECK: tf_executor.island %island = tf_executor.island { // CHECK-NEXT: "tf.Identity" // CHECK-SAME: {device = "/TPU:0"} - %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) - -> tensor<*x!tf.resource>> + %id0 = "tf.Identity"(%arg0) : (!tf_res) + -> !tf_res tf_executor.yield } tf_executor.fetch %island : !tf_executor.control @@ -108,31 +111,31 @@ func @if_else( // ----- +!tf_res = type tensor<*x!tf.resource>> // Tesets that the pass can propagate through tf.While's branches. 
- // CHECK-LABEL: func @propagate_while_op func @propagate_while_op( - %arg0: tensor<*x!tf.resource>> {tf.device = "/TPU:0"}, + %arg0: !tf_res {tf.device = "/TPU:0"}, %arg1: tensor) { tf_executor.graph { // CHECK: tf_executor.island %island = tf_executor.island { // CHECK-NEXT: "tf.Identity" // CHECK-SAME: {device = "/TPU:0"} - %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) - -> tensor<*x!tf.resource>> + %id0 = "tf.Identity"(%arg0) : (!tf_res) + -> !tf_res // CHECK-NEXT: "tf.VarHandleOp" %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} - : () -> tensor<*x!tf.resource>> + : () -> !tf_res // CHECK-NEXT: "tf.While" "tf.While"(%arg1, %id0, %var_handle) { body = @while_body, cond = @while_cond, is_stateless = false} - : (tensor, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>) -> - (tensor, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>) + : (tensor, !tf_res, + !tf_res) -> + (tensor, !tf_res, + !tf_res) tf_executor.yield } tf_executor.fetch %island : !tf_executor.control @@ -143,48 +146,48 @@ func @propagate_while_op( // CHECK-LABEL: func @while_body func @while_body( %arg0: tensor, - %arg1: tensor<*x!tf.resource>>, - %arg2: tensor<*x!tf.resource>>) -> - (tensor, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>) { + %arg1: !tf_res, + %arg2: !tf_res) -> + (tensor, !tf_res, + !tf_res) { %graph:3 = tf_executor.graph { // CHECK: tf_executor.island %island:4 = tf_executor.island { // CHECK-NEXT: "tf.Identity" // CHECK-SAME: {device = "/TPU:0"} - %id0 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>) - -> tensor<*x!tf.resource>> + %id0 = "tf.Identity"(%arg1) : (!tf_res) + -> !tf_res // CHECK-NEXT: "tf.Identity" // CHECK-SAME: {device = "/TPU:1"} - %id1 = "tf.Identity"(%arg2) : (tensor<*x!tf.resource>>) - -> tensor<*x!tf.resource>> + %id1 = "tf.Identity"(%arg2) : (!tf_res) + -> !tf_res tf_executor.yield %arg0, %id0, %id1 - : tensor, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>> + : tensor, !tf_res, + !tf_res } tf_executor.fetch %island#0, %island#1, %island#2 - : tensor, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>> + : tensor, !tf_res, + !tf_res } return %graph#0, %graph#1, %graph#2 - : tensor, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>> + : tensor, !tf_res, + !tf_res } // CHECK-LABEL: func @while_cond func @while_cond( %arg0: tensor, - %arg1: tensor<*x!tf.resource>>, - %arg2: tensor<*x!tf.resource>>) -> tensor<32xf32> { + %arg1: !tf_res, + %arg2: !tf_res) -> tensor<32xf32> { %graph = tf_executor.graph { // CHECK: tf_executor.island %island:2 = tf_executor.island { // CHECK-NEXT: "tf.Identity" // CHECK-SAME: {device = "/TPU:0"} - %id0 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>) - -> tensor<*x!tf.resource>> + %id0 = "tf.Identity"(%arg1) : (!tf_res) + -> !tf_res %read = "tf.ReadVariableOp"(%id0) - : (tensor<*x!tf.resource>>) -> tensor<32xf32> + : (!tf_res) -> tensor<32xf32> tf_executor.yield %read : tensor<32xf32> } tf_executor.fetch %island#0 : tensor<32xf32> @@ -193,31 +196,32 @@ func @while_cond( } // ----- +!tf_res = type tensor<*x!tf.resource>> // Tesets that the pass reports error on conflicting assignments from multiple // callers. 
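The conflict below comes from calling the same branch function twice with the resource operands swapped: the first `tf.If` binds the callee's first resource argument to the `/TPU:0` handle, the second binds it to the `/TPU:1` handle, so no single device can be inferred for that argument and the pass reports an error rather than picking one. Schematically, mirroring the test with the predicate type written out:

"tf.If"(%arg1, %id0, %var_handle) {then_branch = @if_then_and_else, else_branch = @if_then_and_else, is_stateless = false} : (tensor<i1>, !tf_res, !tf_res) -> ()
// Same callee, operands swapped: its %arg0 would now need "/TPU:1" as well, which conflicts.
"tf.If"(%arg1, %var_handle, %id0) {then_branch = @if_then_and_else, else_branch = @if_then_and_else, is_stateless = false} : (tensor<i1>, !tf_res, !tf_res) -> ()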
func @error_on_conflict_multiple_callers( - %arg0: tensor<*x!tf.resource>> {tf.device = "/TPU:0"}, + %arg0: !tf_res {tf.device = "/TPU:0"}, %arg1: tensor) { tf_executor.graph { %island = tf_executor.island { - %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) - -> tensor<*x!tf.resource>> + %id0 = "tf.Identity"(%arg0) : (!tf_res) + -> !tf_res %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} - : () -> tensor<*x!tf.resource>> + : () -> !tf_res "tf.If"(%arg1, %id0, %var_handle) { then_branch = @if_then_and_else, else_branch = @if_then_and_else, is_stateless = false} - : (tensor, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>) -> () + : (tensor, !tf_res, + !tf_res) -> () "tf.If"(%arg1, %var_handle, %id0) { // expected-error@above {{Conflicting device assignment for resource}} then_branch = @if_then_and_else, else_branch = @if_then_and_else, is_stateless = false} - : (tensor, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>) -> () + : (tensor, !tf_res, + !tf_res) -> () tf_executor.yield } tf_executor.fetch %island : !tf_executor.control @@ -226,17 +230,311 @@ func @error_on_conflict_multiple_callers( } func @if_then_and_else( - %arg0: tensor<*x!tf.resource>>, - %arg1: tensor<*x!tf.resource>>) { + %arg0: !tf_res, + %arg1: !tf_res) { tf_executor.graph { %island = tf_executor.island { - %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) - -> tensor<*x!tf.resource>> - %id1 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>) - -> tensor<*x!tf.resource>> + %id0 = "tf.Identity"(%arg0) : (!tf_res) + -> !tf_res + %id1 = "tf.Identity"(%arg1) : (!tf_res) + -> !tf_res tf_executor.yield } tf_executor.fetch %island : !tf_executor.control } return } + +// ----- + +// Test that the pass can propagate through calls +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @test_function +// CHECK-SAME: {tf.device = "/TPU:0"} +func @test_function(%arg0: !tf_res) { + // CHECK: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res + %read = "tf.ReadVariableOp"(%id0) : (!tf_res) -> tensor<32xf32> + %cst = constant dense<3.0> : tensor<32xf32> + %add = "tf.AddV2"(%read, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> + "tf.AssignVariableOp"(%arg0, %add) : (!tf_res, tensor<32xf32>) -> () + return +} + +// CHECK-LABEL: func @propagate_through_calls +func @propagate_through_calls( + %arg0: !tf_res {tf.device = "/TPU:0"}, + %arg1: !tf_res {tf.device = "/TPU:1"}) { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.VarHandleOp" + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/CPU:0"} + : () -> !tf_res + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (!tf_res) + -> !tf_res + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id1 = "tf.Identity"(%id0) : (!tf_res) + -> !tf_res + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/CPU:0"} + %id2 = "tf.Identity"(%var_handle) : (!tf_res) + -> !tf_res + %read = "tf.ReadVariableOp"(%id2) : (!tf_res) -> tensor<32xf32> + %id3 = "tf.Identity"(%read) : (tensor<32xf32>) -> tensor<32xf32> + call @test_function(%id1) : (!tf_res) -> () + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// Test propagation through IfRegion (with non-inlined calls) +// CHECK-LABEL: func @propagate_if_region +func @propagate_if_region( + %arg0: !tf_res {tf.device = "/TPU:0"}, + 
%arg1: tensor) { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (!tf_res) + -> !tf_res + // CHECK-NEXT: "tf.VarHandleOp" + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} + : () -> !tf_res + // CHECK-NEXT: "tf.IfRegion" + "tf.IfRegion"(%arg1) ({ + call @ifregion_then(%id0, %var_handle) : (!tf_res, !tf_res) -> () + "tf.Yield"() : () -> () + }, { + call @ifregion_else(%id0, %var_handle) : (!tf_res, !tf_res) -> () + "tf.Yield"() : () -> () + }) {is_stateless = false} : (tensor) -> () + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// CHECK-LABEL: func @ifregion_then +// CHECK-SAME: (%arg0: {{.+}} {tf.device = "/TPU:0"}, %arg1: {{.+}} {tf.device = "/TPU:1"} +func @ifregion_then( + %arg0: !tf_res, + %arg1: !tf_res) { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (!tf_res) + -> !tf_res + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:1"} + %id1 = "tf.Identity"(%arg1) : (!tf_res) + -> !tf_res + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// CHECK-LABEL: func @ifregion_else +// CHECK-SAME: (%arg0: {{.+}} {tf.device = "/TPU:0"}, %arg1: {{.+}} {tf.device = "/TPU:1"} +func @ifregion_else( + %arg0: !tf_res, + %arg1: !tf_res) { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (!tf_res) + -> !tf_res + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:1"} + %id1 = "tf.Identity"(%arg1) : (!tf_res) + -> !tf_res + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// Test progagation through IfRegion (inlined calls) +// CHECK-LABEL: func @propagate_if_region_inlined +func @propagate_if_region_inlined( + %arg0: !tf_res {tf.device = "/TPU:0"}, + %arg1: tensor) { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (!tf_res) + -> !tf_res + // CHECK-NEXT: "tf.VarHandleOp" + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} + : () -> !tf_res + // CHECK-NEXT: "tf.IfRegion" + "tf.IfRegion"(%arg1) ({ + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id1 = "tf.Identity"(%id0) : (!tf_res) -> !tf_res + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:1"} + %id2 = "tf.Identity"(%var_handle) : (!tf_res) -> !tf_res + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + "tf.Yield"() : () -> () + }, { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id1 = "tf.Identity"(%id0) : (!tf_res) -> !tf_res + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:1"} + %id2 = "tf.Identity"(%var_handle) : (!tf_res) -> !tf_res + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + "tf.Yield"() : () -> () + }) {is_stateless = false} : (tensor) -> () + tf_executor.yield + } + 
tf_executor.fetch %island : !tf_executor.control + } + return +} + +// Test propagation through WhileRegion (inlined calls) +// CHECK-LABEL: func @propagate_while_region_inlined +func @propagate_while_region_inlined( + %arg0: !tf_res {tf.device = "/TPU:0"}, + %arg1: tensor) { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res + // CHECK-NEXT: "tf.VarHandleOp" + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} : () -> !tf_res + // CHECK-NEXT: "tf.WhileRegion" + "tf.WhileRegion"(%arg1, %id0, %var_handle) ({ + ^bb0(%carg0: tensor, %carg1: !tf_res, %carg2: !tf_res): + // CHECK: ^bb + // CHECK: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %cid0 = "tf.Identity"(%carg1) : (!tf_res) -> !tf_res loc("cid0") + %read = "tf.ReadVariableOp"(%cid0) : (!tf_res) -> tensor<32xf32> + %cst = constant dense<3.0> : tensor<32xf32> + %cmp = "tf.Less"(%read, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xi1> + %dims = constant dense<0> : tensor<1xi32> + %reduce = "tf.All"(%cmp, %dims) {keep_dims = false} : (tensor<32xi1>, tensor<1xi32>) -> tensor + "tf.Yield"(%reduce) : (tensor) -> () + }, { + ^bb0(%barg0: tensor, %barg1: !tf_res, %barg2: !tf_res): + // CHECK: ^bb + // CHECK: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %bid0 = "tf.Identity"(%barg1) : (!tf_res) -> !tf_res + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:1"} + %id1 = "tf.Identity"(%barg2) : (!tf_res) -> !tf_res + "tf.Yield"(%barg0, %bid0, %id1) : (tensor, !tf_res,!tf_res) -> () + }){is_stateless = false} + : (tensor, !tf_res, !tf_res) -> (tensor, !tf_res, !tf_res) + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// Test propagation through WhileRegion (non-inlined calls) +// CHECK-LABEL: func @propagate_while_region +func @propagate_while_region( + %arg0: !tf_res {tf.device = "/TPU:0"}, + %arg1: tensor) { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res + // CHECK-NEXT: "tf.VarHandleOp" + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} : () -> !tf_res + // CHECK-NEXT: "tf.WhileRegion" + "tf.WhileRegion"(%arg1, %id0, %var_handle) ({ + ^bb0(%carg0: tensor, %carg1: !tf_res, %carg2: !tf_res): + %cond = call @whileregion_cond(%carg0, %carg1, %carg2) : (tensor, !tf_res, !tf_res) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, { + ^bb0(%barg0: tensor, %barg1: !tf_res, %barg2: !tf_res): + %new_values:3 = call @whileregion_body(%barg0, %barg1, %barg2) : (tensor, !tf_res,!tf_res) -> (tensor, !tf_res,!tf_res) + "tf.Yield"(%new_values#0, %new_values#1, %new_values#2) : (tensor, !tf_res,!tf_res) -> () + }){is_stateless = false} + : (tensor, !tf_res, !tf_res) -> (tensor, !tf_res, !tf_res) + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// CHECK-LABEL: func @whileregion_body +func @whileregion_body(%arg0: tensor, %arg1: !tf_res, %arg2: !tf_res) -> (tensor, !tf_res, !tf_res) { + %graph:3 = tf_executor.graph { + // CHECK: tf_executor.island + %island:4 = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg1) : (!tf_res) -> !tf_res + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: 
{device = "/TPU:1"} + %id1 = "tf.Identity"(%arg2) : (!tf_res) -> !tf_res + tf_executor.yield %arg0, %id0, %id1 : tensor, !tf_res, !tf_res + } + tf_executor.fetch %island#0, %island#1, %island#2 : tensor, !tf_res, !tf_res + } + return %graph#0, %graph#1, %graph#2: tensor, !tf_res, !tf_res +} + +// CHECK-LABEL: func @whileregion_cond +func @whileregion_cond(%arg0: tensor, %arg1: !tf_res, %arg2: !tf_res) -> tensor { + %graph = tf_executor.graph { + // CHECK: tf_executor.island + %island:2 = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg1) : (!tf_res) -> !tf_res + %read = "tf.ReadVariableOp"(%id0) : (!tf_res) -> tensor<32xf32> + %cst = constant dense<3.0> : tensor<32xf32> + %cmp = "tf.Less"(%read, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xi1> + %dims = constant dense<0> : tensor<1xi32> + %reduce = "tf.All"(%cmp, %dims) {keep_dims = false} : (tensor<32xi1>, tensor<1xi32>) -> tensor + tf_executor.yield %reduce : tensor + } + tf_executor.fetch %island#0 : tensor + } + return %graph : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir index ac5c2df8f7e..8457d9c62cd 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir @@ -8,7 +8,7 @@ func @only_resource_load() -> tensor<*xi32> { // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32} + // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) // CHECK: "tf_device.cluster" // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) // CHECK: tf_device.return %[[COMPUTE_RES]] @@ -39,7 +39,7 @@ func @only_resource_store() -> tensor<*xi32> { // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] // CHECK: {cluster_attr = "cluster_attr"} // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>) - // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) {dtype = i32} + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) %1 = "tf_device.cluster"() ( { %2 = "tf.SomeComputation"() : () -> (tensor<*xi32>) @@ -61,13 +61,13 @@ func @same_resource_load_and_store() -> tensor<*xi32> { // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32} + // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster" // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] // CHECK: {cluster_attr = "cluster_attr"} // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>) - // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) {dtype = i32} + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) %1 = "tf_device.cluster"() ( { %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> @@ -112,26 +112,6 @@ func @internal_resource() -> tensor<*xi32> { // ----- -// Tests that pass fails when there are remaining resource operationss that can -// not be lifted. 
- -func @lifting_failure() -> tensor<*xi32> { - - %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - - // expected-error @+1 {{has remaining resource inputs that can not be lifted}} - %1 = "tf_device.cluster"() ( { - %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> - %3 = "tf.SomeResourceOp"(%0, %2) : (tensor<*x!tf.resource>, tensor<*xi32>) -> tensor<*xi32> - "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () - tf_device.return %3 : tensor<*xi32> - }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> - - return %1 : tensor<*xi32> -} - -// ----- - // Tests that pass lifts resource reads/writes from a loop, and removed unused // resources. @@ -328,6 +308,7 @@ func @while_cond1(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!t func @cluster_with_loop() -> () { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> "tf_device.cluster"() ( { + // expected-error@+1 {{result #0 not tied to function argument for branch @while_body}} %1 = "tf.While"(%0) { body = @while_body, cond = @while_cond, device = "", is_stateless = false} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) @@ -337,7 +318,6 @@ func @cluster_with_loop() -> () { } func @while_body(%arg0: tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> - // expected-error @+1 {{resource used in while loop is only supported when the resource input and output alias each other in the loop body}} return %0 : tensor<*x!tf.resource>> } func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { @@ -347,35 +327,12 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { // ----- -// Tests that pass reports error on unsupported ops in loop body. - -func @cluster_with_loop() -> () { - %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> - "tf_device.cluster"() ( { - %1 = "tf.While"(%0) { - body = @while_body, cond = @while_cond, device = "", is_stateless = false} - : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) - tf_device.return - }) {cluster_attr = "cluster_attr"} : () -> () - return -} -func @while_body(%arg0: tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) { - // expected-error @+1 {{found unsupported operations on resource.}} - "tf._UnknownOp"(%arg0) : (tensor<*x!tf.resource>>) -> () - return %arg0 : tensor<*x!tf.resource>> -} -func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { - %read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor - return %read : tensor -} - -// ----- - // Tests that pass reports error on unsupported ops in loop cond. 
func @cluster_with_loop() -> () { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> "tf_device.cluster"() ( { + // expected-error@+1 {{found resource write in loop condition.}} %1 = "tf.While"(%0) { body = @while_body, cond = @while_cond, device = "", is_stateless = false} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) @@ -391,7 +348,6 @@ func @while_body(%arg0: tensor<*x!tf.resource>>) -> (tensor<*x!tf.re func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { %read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor %constant = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor - // expected-error @+1 {{found resource write in loop condition.}} "tf.AssignVariableOp"(%arg0, %constant) : (tensor<*x!tf.resource>>, tensor) -> () return %read : tensor } @@ -409,7 +365,7 @@ func @cluster_with_case(%arg0: tensor) -> tensor<4xf32> { // CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"() %2 = "tf_device.cluster"() ( { // CHECK: %[[CASE:.*]]:2 = "tf.Case"(%[[ARG0]], %[[READ0]], %[[READ1]]) - %3:2 = "tf.Case"(%arg0, %0, %1) {branches = [@branch_0, @branch_1, @branch_2]} + %3:2 = "tf.Case"(%arg0, %0, %1) {branches = [@branch_0, @branch_1, @branch_2], is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>, tensor<4xf32>) // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[CASE]]#1, %[[CASE]]#0) @@ -571,7 +527,7 @@ func @cluster_with_if(%arg0: tensor) -> tensor<4xf32> { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> %2 = "tf_device.cluster"() ( { - // expected-error @+1 {{unsupported output: resource does not alias a single input}} + // expected-error @+1 {{result #0 is not tied to the same argument across all branches}} %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) @@ -598,7 +554,7 @@ func @cluster_with_if(%arg0: tensor) -> tensor<4xf32> { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> %2 = "tf_device.cluster"() ( { - // expected-error @+1 {{unsupported output: resource does not alias input}} + // expected-error @+1 {{result #0 not tied to function argument for branch @if_then}} %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) @@ -757,3 +713,381 @@ func @callee(%arg0: tensor<*x!tf.resource>>) -> tensor { // CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor) -> tensor // CHECK-NEXT: return %[[A0]] + +// ----- + +// Test that the pass can lift resources out of IfRegion +// CHECK: func @cluster_with_ifregion(%[[ARG0:.*]]: tensor) -> tensor<4xf32> +func @cluster_with_ifregion(%arg0: tensor) -> tensor<4xf32> { + // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"() + // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"() + // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]]) + // CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"() + // CHECK: %[[IF:.*]]:2 = "tf.IfRegion"(%[[ARG0]]) + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> + %2 = "tf_device.cluster"() ( { + 
%3:2 = "tf.IfRegion"(%arg0) ({ + // CHECK-NEXT: %[[CONST:.*]] = "tf.Const"() + // CHECK-NEXT: "tf.Yield"(%[[CONST]], %[[CONST]]) + %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + }, { + // CHECK: "tf.Yield"(%[[READ1]], %[[READ1]]) + %id = "tf.Identity"(%1) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + %read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"(%0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + }) {is_stateless = false} : (tensor) -> (tensor<*x!tf.resource>>, tensor<4xf32>) + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[IF]]#1, %[[IF]]#0) + // CHECK-NEXT: tf_device.return %[[ADD]], %[[IF]]#1 + %4 = "tf.ReadVariableOp"(%3#0) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + %5 = "tf.AddV2"(%4, %3#1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + tf_device.return %5 : tensor<4xf32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<4xf32> + // CHECK: {cluster_attr = "cluster_attr"} : () -> (tensor<4xf32>, tensor<4xf32>) + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]#1) + // CHECK: return %[[CLUSTER]]#0 + return %2 : tensor<4xf32> +} + +// Test that the pass can lift resources out of CaseRegion +// CHECK: func @cluster_with_caseregion(%[[ARG0:.*]]: tensor) -> tensor<4xf32> +func @cluster_with_caseregion(%arg0: tensor) -> tensor<4xf32> { + // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"() + // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"() + // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]]) + // CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"() + // CHECK: %[[CASE:.*]]:2 = "tf.CaseRegion"(%[[ARG0]]) + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> + %2 = "tf_device.cluster"() ( { + %3:2 = "tf.CaseRegion"(%arg0) ({ + // CHECK-NEXT: %[[CONST:.*]] = "tf.Const"() + // CHECK-NEXT: "tf.Yield"(%[[CONST]], %[[CONST]]) + %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + }, { + // CHECK: "tf.Yield"(%[[READ1]], %[[READ1]]) + %id = "tf.Identity"(%1) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + %read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"(%0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + }, { + // CHECK: %[[CONST1:.*]] = "tf.Const" + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[READ1]], %[[CONST1]]) + // CHECK: "tf.Yield"(%[[READ1]], %[[SUB]]) + %id = "tf.Identity"(%1) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + %read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + %constant = "tf.Const"() {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> + %sub = "tf.Sub"(%read, %constant) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %sub) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"(%0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + }) {is_stateless = false} : (tensor) -> 
(tensor<*x!tf.resource>>, tensor<4xf32>) + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[CASE]]#1, %[[CASE]]#0) + // CHECK-NEXT: tf_device.return %[[ADD]], %[[CASE]]#1 + %4 = "tf.ReadVariableOp"(%3#0) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + %5 = "tf.AddV2"(%4, %3#1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + tf_device.return %5 : tensor<4xf32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<4xf32> + // CHECK: {cluster_attr = "cluster_attr"} : () -> (tensor<4xf32>, tensor<4xf32>) + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]#1) + // CHECK: return %[[CLUSTER]]#0 + return %2 : tensor<4xf32> +} + +// ----- + +// Test that the pass can lift resources out of WhileRegion +// CHECK-LABEL: func @cluster_with_whileregion +func @cluster_with_whileregion() -> () { + // CHECK: %[[COUNT:.*]] = "tf.Const"() {value = dense<10> : tensor} + // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() + // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + // CHECK: %[[WHILE:.*]]:2 = "tf.WhileRegion"(%[[COUNT]], %[[READ]]) + %0 = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + %unused = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> + "tf_device.cluster"() ( { + %2:3 = "tf.WhileRegion"(%0, %1, %unused) ({ + // CHECK: (%[[CARG0:.+]]: tensor, %[[CARG1:.+]]: tensor): + // CHECK: %[[CAST:.+]] = "tf.Cast"(%[[CARG1]]) + // CHECK: "tf.Less"(%[[CARG0]], %[[CAST]]) + // CHECK: "tf.Yield" + ^bb0(%carg0: tensor, %carg1:tensor<*x!tf.resource>>, %carg2: tensor<*x!tf.resource>>): + %read0 = "tf.ReadVariableOp"(%carg1) : (tensor<*x!tf.resource>>) -> tensor + %cast = "tf.Cast"(%read0) : (tensor) -> tensor + %cond = "tf.Less"(%carg0, %cast) : (tensor, tensor) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, { + // CHECK: (%[[BARG0:.+]]: tensor, %[[BARG1:.+]]: tensor): + // CHECK: %[[ADD0:.*]] = "tf.AddV2"(%[[BARG1]], %[[BARG1]]) + // CHECK-NEXT: %[[ADD1:.*]] = "tf.AddV2"(%[[ADD0]], %[[ADD0]]) + // CHECK-NEXT: %[[DELTA:.*]] = "tf.Const"() {value = dense<-1> : tensor} + // CHECK-NEXT: %[[ADD2:.*]] = "tf.AddV2"(%[[BARG0]], %[[DELTA]]) + // CHECK-NEXT: "tf.Yield"(%[[ADD2]], %[[ADD1]]) + ^bb1(%barg0: tensor, %barg1:tensor<*x!tf.resource>>, %barg2: tensor<*x!tf.resource>>): + %read0 = "tf.ReadVariableOp"(%barg1) : (tensor<*x!tf.resource>>) -> tensor + %add0 = "tf.AddV2"(%read0, %read0) : (tensor, tensor) -> tensor + "tf.AssignVariableOp"(%barg1, %add0) : (tensor<*x!tf.resource>>, tensor) -> () + %read1 = "tf.ReadVariableOp"(%barg1) : (tensor<*x!tf.resource>>) -> tensor + %add1 = "tf.AddV2"(%read1, %read1) : (tensor, tensor) -> tensor + "tf.AssignVariableOp"(%barg1, %add1) : (tensor<*x!tf.resource>>, tensor) -> () + %constant = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %add2 = "tf.AddV2"(%barg0, %constant) : (tensor, tensor) -> tensor + %id = "tf.Identity"(%barg2) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + "tf.Yield"(%add2, %barg1, %id) : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> () + }) {device = "", is_stateless = false} + : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) + -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + // CHECK: tf_device.return %[[WHILE]]#1 : tensor + // CHECK: {cluster_attr = "cluster_attr"} : () -> tensor + // CHECK: "tf.AssignVariableOp"(%[[VH]], %[[CLUSTER]]) 
+ // CHECK: return + return +} + +// ----- + +// Test that the pass can lift out recursively (If with another if it its body) +// CHECK: func @cluster_with_if_within_if(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -> tensor<4xf32> +func @cluster_with_if_within_if(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { + // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"() + // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"() + // CHECK: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]]) + // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]]) + // CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"() + // CHECK: %[[IF:.*]]:2 = "tf.IfRegion"(%[[ARG0]]) + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> + %2 = "tf_device.cluster"() ( { + %3:2 = "tf.IfRegion"(%arg0) ({ + // CHECK-NEXT: %[[CONST:.*]] = "tf.Const"() + // CHECK-NEXT: "tf.Yield"(%[[CONST]], %[[CONST]]) + %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + }, { + // CHECK: %[[IF1:.*]] = "tf.IfRegion" + // CHECK: "tf.Yield"(%[[READ1]]) + // CHECK: "tf.Yield"(%[[READ0]]) + // CHECK: "tf.Yield"(%[[IF1]], %[[IF1]]) + %id = "tf.Identity"(%1) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + %read = "tf.IfRegion"(%arg1) ({ + %read_then = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + "tf.Yield"(%read_then) : (tensor<4xf32>) -> () + }, { + %read_else = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + "tf.Yield"(%read_else) : (tensor<4xf32>) -> () + }) {is_stateless = false} : (tensor) -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"(%0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + }) {is_stateless = false} : (tensor) -> (tensor<*x!tf.resource>>, tensor<4xf32>) + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[IF]]#1, %[[IF]]#0) + // CHECK-NEXT: tf_device.return %[[ADD]], %[[IF]]#1 + %4 = "tf.ReadVariableOp"(%3#0) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + %5 = "tf.AddV2"(%4, %3#1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + tf_device.return %5 : tensor<4xf32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<4xf32> + // CHECK: {cluster_attr = "cluster_attr"} : () -> (tensor<4xf32>, tensor<4xf32>) + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]#1) + // CHECK: return %[[CLUSTER]]#0 + return %2 : tensor<4xf32> +} + +// ----- + +// IfRegion with store in just one branch + +// CHECK: func @if_region_with_store_in_then(%[[ARG0:.*]]: tensor) +func @if_region_with_store_in_then(%arg0: tensor) { + // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() + // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + // CHECK: %[[IF:.*]] = "tf.IfRegion"(%[[ARG0]]) + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + "tf_device.cluster"() ({ + "tf.IfRegion"(%arg0) ({ + // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<0.000000e+00> + // CHECK: "tf.Yield"(%[[CONST]]) + %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"() : () -> () + }, { + // CHECK: "tf.Yield"(%[[READ]]) + "tf.Yield"() : () -> () + 
}) { is_stateless = true} : (tensor) -> () + tf_device.return + }) { cluster_attr = "cluster_attr" } : () -> () + // CHECK: tf_device.return %[[IF]] + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]) + return +} + +// ----- + +// IfRegion with store in both branches + +// CHECK: func @if_region_with_store_in_both(%[[ARG0:.*]]: tensor) +func @if_region_with_store_in_both(%arg0: tensor) { + // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + // CHECK: %[[IF:.*]] = "tf.IfRegion"(%[[ARG0]]) + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + "tf_device.cluster"() ({ + "tf.IfRegion"(%arg0) ({ + // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<0.000000e+00> + // CHECK: "tf.Yield"(%[[CONST]]) + %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"() : () -> () + }, { + // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<1.000000e+00> + // CHECK: "tf.Yield"(%[[CONST]]) + %constant = "tf.Const"() {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> + "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + "tf.Yield"() : () -> () + }) { is_stateless = true} : (tensor) -> () + tf_device.return + }) { cluster_attr = "cluster_attr" } : () -> () + // CHECK: tf_device.return %[[IF]] + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]) + return +} + + +// Make sure unsupported resources are handled correctly. If a resource is used +// in an unsupported op, resource op lifting should skip lifting that resource. +// So for the below test, the IR should stay unchanged. +// CHECK-LABEL: func @test_unsupported_resource_op +func @test_unsupported_resource_op() -> tensor<*xi32> { + // CHECK: "tf.VarHandleOp" + // CHECK: "tf_device.cluster"() ( { + // CHECK: "tf.ReadVariableOp" + // CHECK: "tf.SomeResourceOperation" + // CHECK: "tf.SomeComputation" + // CHECK: tf_device.return + // CHECK: {cluster_attr = "cluster_attr"} + // CHECK: return + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + %1 = "tf_device.cluster"() ( { + %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> + "tf.SomeResourceOperation"(%0) : (tensor<*x!tf.resource>) -> () + %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>) + tf_device.return %3 : tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> + + return %1 : tensor<*xi32> +} + +// Test unsupported use of resource ops in functional control flow. In the test +// below, arg0 has an unsupported use whereas arg1 does not. So we expect arg0 +// to not be lifted and arg1 to be lifted. 
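For readers unfamiliar with the pass, the sketch below shows what "lifting" does to a simple read-modify-write around a cluster. It is modeled on the @same_resource_load_and_store CHECK pattern earlier in this file; the op name "tf.SomeComputation" and the f32 element type are illustrative, not part of the change.

func @lifting_sketch_before() -> tensor<f32> {
  %v = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
  %r = "tf_device.cluster"() ( {
    // Read, compute and write all happen inside the cluster.
    %x = "tf.ReadVariableOp"(%v) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
    %y = "tf.SomeComputation"(%x) : (tensor<f32>) -> tensor<f32>
    "tf.AssignVariableOp"(%v, %y) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
    tf_device.return %y : tensor<f32>
  }) {cluster_attr = "cluster_attr"} : () -> tensor<f32>
  return %r : tensor<f32>
}

func @lifting_sketch_after() -> tensor<f32> {
  %v = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
  // The read is hoisted above the cluster and the written value is forwarded
  // out as an extra cluster result, so the cluster body is resource free.
  %x = "tf.ReadVariableOp"(%v) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
  %r:2 = "tf_device.cluster"() ( {
    %y = "tf.SomeComputation"(%x) : (tensor<f32>) -> tensor<f32>
    tf_device.return %y, %y : tensor<f32>, tensor<f32>
  }) {cluster_attr = "cluster_attr"} : () -> (tensor<f32>, tensor<f32>)
  "tf.AssignVariableOp"(%v, %r#1) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
  return %r#0 : tensor<f32>
}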
+// CHECK-LABEL: func @test_unsupported_resource_op_in_if +func @test_unsupported_resource_op_in_if(%arg0: tensor) -> tensor<*xi32> { + // CHECK: [[VH0:%.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} + // CHECK: [[VH1:%.*]] = "tf.VarHandleOp"() {container = "d", shared_name = "w"} + // CHECK-NOT: "tf.ReadVariableOp"([[VH0]]) + // CHECK: [[READ1:%.*]] = "tf.ReadVariableOp"([[VH1]]) + // CHECK-NOT: "tf.ReadVariableOp"([[VH0]]) + // CHECK: "tf_device.cluster"() ( { + // CHECK: "tf.If"({{%.*}}, [[VH0]], [[READ1]]) + // CHECK-SAME: else_branch = @else_fn, is_stateless = true, then_branch = @then_fn + // CHECK: tf_device.return + // CHECK: return + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + %1 = "tf.VarHandleOp"() {container = "d", shared_name = "w"} : () -> tensor<*x!tf.resource> + %2 = "tf_device.cluster"() ( { + %3 = "tf.If"(%arg0, %0, %1) + { else_branch = @else_fn, then_branch = @then_fn, is_stateless = true} + : (tensor, tensor<*x!tf.resource>, tensor<*x!tf.resource>) -> tensor<*xi32> + tf_device.return %3 : tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> + return %2 : tensor<*xi32> +} + +// CHECK-LABEL: func @else_fn +// CHECK-SAME: (%{{.*}}: tensor<*x!tf.resource>, %{{.*}}: tensor<*xi32>) +func @else_fn(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) -> tensor<*xi32> { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<*xi32> + %1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>) -> tensor<*xi32> + %2 = "tf.Add"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + return %2 : tensor<*xi32> +} + +// CHECK-LABEL: func @then_fn +// CHECK-SAME: (%{{.*}}: tensor<*x!tf.resource>, %{{.*}}: tensor<*xi32>) +func @then_fn(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) -> tensor<*xi32> { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<*xi32> + %1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>) -> tensor<*xi32> + %2 = "tf.Add"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + "tf.UnsupportedResourceOp"(%arg0) : (tensor<*x!tf.resource>) -> () + return %2 : tensor<*xi32> +} + +// Test type refinement. If the resource has a single subtype, check that that +// type gets used when hoisting the read. None of the result types will change. 
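Written out in full, the resource types in the two refinement tests carry an element subtype, e.g. tensor<*x!tf.resource<tensor<4xi32>>> for a handle whose variable is known to be tensor<4xi32>. A minimal sketch of the refinement, assuming that syntax; only the hoisted read changes type, the cluster result types do not:

func @subtype_refinement_sketch() -> tensor<*xi32> {
  // Handle with a single subtype, tensor<4xi32>.
  %h = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<4xi32>>>
  // The hoisted read adopts the subtype instead of tensor<*xi32> ...
  %x = "tf.ReadVariableOp"(%h) : (tensor<*x!tf.resource<tensor<4xi32>>>) -> tensor<4xi32>
  %r = "tf_device.cluster"() ( {
    // ... while the body and the cluster keep their original result types.
    %y = "tf.SomeComputation"(%x) : (tensor<4xi32>) -> tensor<*xi32>
    tf_device.return %y : tensor<*xi32>
  }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
  return %r : tensor<*xi32>
}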
+// CHECK-LABEL: func @type_refinement_use_subtype +func @type_refinement_use_subtype() -> tensor<*xi32> { + + // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + + // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) + // CHECK-SAME: -> tensor<4xi32> + // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster" + // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) : (tensor<4xi32>) -> tensor<*xi32> + // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] + // CHECK-SAME: tensor<*xi32>, tensor<*xi32> + // CHECK: {cluster_attr = "cluster_attr"} + // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>) + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) + + %1 = "tf_device.cluster"() ( { + %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>>) -> tensor<*xi32> + %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>) + "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>>, tensor<*xi32>) -> () + tf_device.return %3 : tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> + + // CHECK: return %[[CLUSTER_RES]]#0 + // CHECK-SAME: tensor<*xi32> + return %1 : tensor<*xi32> +} + +// If multiple types are used across reads and writes, check that the read uses +// the most refined type. The first ReadVariable should refine the type from +// *xi32 to ?xi32 and the assign should refine it further to 4xi32. +// CHECK-LABEL: func @type_refinement_use_refined_type +func @type_refinement_use_refined_type() -> tensor<4xi32> { + + // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + + // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) + // CHECK-SAME: -> tensor<4xi32> + // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster" + // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) : (tensor<4xi32>) -> tensor<4xi32> + // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] + // CHECK-SAME: tensor<4xi32>, tensor<4xi32> + // CHECK: {cluster_attr = "cluster_attr"} + // CHECK-SAME: () -> (tensor<4xi32>, tensor<4xi32>) + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) + + %1 = "tf_device.cluster"() ( { + %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>>) -> tensor + %3 = "tf.SomeComputation"(%2) : (tensor) -> (tensor<4xi32>) + "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>>, tensor<4xi32>) -> () + tf_device.return %3 : tensor<4xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<4xi32> + + // CHECK: return %[[CLUSTER_RES]]#0 + // CHECK-SAME: tensor<4xi32> + return %1 : tensor<4xi32> +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 4a5e3c8deaa..26df60229e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -2,69 +2,69 @@ // RUN: tf-opt %s -tf-shape-inference=propagate-caller-callee-constants -verify-diagnostics | FileCheck %s module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} { -// CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> + // CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) 
-> tensor<1xi32> func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> { - // CHECK: %[[RESULT:.*]] = "tf.AddV2" - // CHECK-SAME: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - // CHECK: return %[[RESULT]] : tensor<1xi32> + // CHECK: %[[RESULT:.*]] = "tf.AddV2" + // CHECK-SAME: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: return %[[RESULT]] : tensor<1xi32> %0 = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32> %1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32> %2 = "tf.AddV2"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> return %2 : tensor<*xi32> } -// CHECK-LABEL: func @simple_chain + // CHECK-LABEL: func @simple_chain func @simple_chain(%arg0: tensor<1xf32>) -> tensor<*xf32> { -// CHECK: %[[MUL:.*]] = "tf.Mul"{{.*}} (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> -// CHECK: %[[ADD:.*]] = "tf.Add"(%[[MUL]], %[[MUL]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> -// CHECK: return %[[ADD]] : tensor<1xf32> + // CHECK: %[[MUL:.*]] = "tf.Mul"{{.*}} (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // CHECK: %[[ADD:.*]] = "tf.Add"(%[[MUL]], %[[MUL]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // CHECK: return %[[ADD]] : tensor<1xf32> %0 = "tf.Mul"(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32> %1 = "tf.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %1 : tensor<*xf32> } -// CHECK-LABEL: func @simple_chain_with_broadcast + // CHECK-LABEL: func @simple_chain_with_broadcast func @simple_chain_with_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<10xf32>) -> tensor<*xf32> { -// CHECK: %[[MUL:.*]] = "tf.Mul"{{.*}} (tensor<1xf32>, tensor<10xf32>) -> tensor<10xf32> -// CHECK: %[[ADD:.*]] = "tf.Add"(%[[MUL]], %[[MUL]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> -// CHECK: %[[CAST:.*]] = "tf.Cast"(%[[ADD]]) {{.*}} : (tensor<10xf32>) -> tensor<*xf32> -// CHECK: %[[UNKNOWN:.*]] = addf %[[CAST]], %[[CAST]] : tensor<*xf32> -// CHECK: return %[[UNKNOWN]] : tensor<*xf32> + // CHECK: %[[MUL:.*]] = "tf.Mul"{{.*}} (tensor<1xf32>, tensor<10xf32>) -> tensor<10xf32> + // CHECK: %[[ADD:.*]] = "tf.Add"(%[[MUL]], %[[MUL]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + // CHECK: %[[CAST:.*]] = "tf.Cast"(%[[ADD]]) {{.*}} : (tensor<10xf32>) -> tensor<*xf32> + // CHECK: %[[UNKNOWN:.*]] = addf %[[CAST]], %[[CAST]] : tensor<*xf32> + // CHECK: return %[[UNKNOWN]] : tensor<*xf32> %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1xf32>, tensor<10xf32>) -> tensor<*xf32> %1 = "tf.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> %2 = addf %1, %1 : tensor<*xf32> return %2 : tensor<*xf32> } -// CHECK-LABEL: func @unknown_op + // CHECK-LABEL: func @unknown_op func @unknown_op(%arg0: tensor<1xf32>) -> tensor<*xf32> { -// CHECK: %[[MUL:.*]] = "tf.Mul"{{.*}} (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> -// CHECK: %[[UNKNOWN:.*]] = "tf.Unknown"(%[[MUL]], %[[MUL]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32> -// CHECK: return %[[UNKNOWN]] : tensor<*xf32> + // CHECK: %[[MUL:.*]] = "tf.Mul"{{.*}} (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // CHECK: %[[UNKNOWN:.*]] = "tf.Unknown"(%[[MUL]], %[[MUL]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32> + // CHECK: return %[[UNKNOWN]] : tensor<*xf32> %0 = "tf.Mul"(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32> %1 = "tf.Unknown"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %1 : tensor<*xf32> } -// CHECK-LABEL: func @multiple_blocks_one_return(%arg0: tensor) -> tensor -func @multiple_blocks_one_return(%arg0: tensor) -> 
tensor<*xf32> { - br ^bb1 -^bb1: -// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%arg0) : (tensor) -> tensor -// CHECK: return %[[IDENTITY]] : tensor - %ret = "tf.Identity"(%arg0) : (tensor) -> tensor<*xf32> - return %ret : tensor<*xf32> -} + // CHECK-LABEL: func @multiple_blocks_one_return(%arg0: tensor) -> tensor + func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { + br ^bb1 + ^bb1: + // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%arg0) : (tensor) -> tensor + // CHECK: return %[[IDENTITY]] : tensor + %ret = "tf.Identity"(%arg0) : (tensor) -> tensor<*xf32> + return %ret : tensor<*xf32> + } -// Tests the case where an inference opportunity relies on folding. + // Tests the case where an inference opportunity relies on folding. -// CHECK-LABEL: func @simple_folding + // CHECK-LABEL: func @simple_folding func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor { -// CHECK: %[[SHAPE:.*]] = "tf.Shape" -// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[SHAPE]] -// CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> -// CHECK: return %[[CONV]] : tensor<1x1x1x1xf32> + // CHECK: %[[SHAPE:.*]] = "tf.Shape" + // CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[SHAPE]] + // CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> + // CHECK: return %[[CONV]] : tensor<1x1x1x1xf32> %0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32> %1 = "tf.Conv2DBackpropInput"(%0, %arg1, %arg1) { padding = "VALID", strides = [1, 1, 1, 1] @@ -72,7 +72,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %1 : tensor } -// Tests where tf.Const's value needs to be refined. + // Tests where tf.Const's value needs to be refined. func @const_refine() -> tensor<*xi32> { %0 = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<*xi32> @@ -81,9 +81,9 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %0 : tensor<*xi32> } -// Tests the case where an op's shape function returns non-fully-defined shapes. + // Tests the case where an op's shape function returns non-fully-defined shapes. 
-// CHECK-LABEL: func @op_non_fully_defined_shape_fn + // CHECK-LABEL: func @op_non_fully_defined_shape_fn func @op_non_fully_defined_shape_fn(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor { // CHECK: tf.BroadcastGradientArgs // CHECK-SAME: (tensor<0xi32>, tensor<0xi32>) -> (tensor, tensor) @@ -91,7 +91,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %2#0 : tensor } -// CHECK-LABEL: func @shape_from_const_input + // CHECK-LABEL: func @shape_from_const_input func @shape_from_const_input(%arg0: tensor<3x3x32x64xf32>, %arg1: tensor<200x24x24x64xf32>) -> tensor { %0 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<4xi32> // CHECK: tf.Conv2DBackpropInput @@ -223,7 +223,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-SAME: %[[ARG_1:.*]]: tensor>> func @shape_from_case_to_branch_functions(%arg0: tensor, %arg1: tensor>>) -> tensor<1x2x3xf32> { // CHECK: %[[CASE:.*]] = "tf.Case"(%[[ARG_0]], %[[ARG_1]]) - %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch_0, @branch_1]} : (tensor, tensor>>) -> tensor<1x2x3xf32> + %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch_0, @branch_1], is_stateless = false} : (tensor, tensor>>) -> tensor<1x2x3xf32> // CHECK: return %[[CASE]] : tensor<1x2x3xf32> return %0 : tensor<1x2x3xf32> } @@ -530,6 +530,21 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %3#0, %3#1 : tensor<*xf32>, tensor<*xf32> } + // CHECK-LABEL: infer_device_cluster + func @infer_device_cluster(%arg0: tensor<1x8x2xi32>) -> (tensor<*xf32>, tensor<*xf32>) { + %0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %1 = "tf_device.cluster"() ({ + %2 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x8x2xi32>) -> tensor<1x8x2xf32> + tf_device.return %2 : tensor<1x8x2xf32> + // CHECK: () -> tensor<1x8x2xf32> + }) : () -> tensor<*xf32> + // CHECK: "tf.Cast"(%{{.*}}) {Truncate = false} : (tensor<1x8x2xf32>) -> tensor<*xf32> + // CHECK: (tensor, tensor<1x8x2xf32>) -> (tensor<1x8x1xf32>, tensor<1x8x1xf32>) + %3:2 = "tf.Split"(%0, %1) {device = ""} : (tensor, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) + %4 = addf %1, %1 : tensor<*xf32> + return %3#0, %3#1 : tensor<*xf32>, tensor<*xf32> + } + // CHECK-LABEL: func @tensor_cast(%arg0: tensor<1xi32>) -> tensor<1xi32> func @tensor_cast(%arg0: tensor<1xi32>) -> tensor<*xi32> { // CHECK: %[[RESULT:.*]] = tensor_cast diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir index d79e028ba9e..5eacbdea180 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir @@ -277,7 +277,7 @@ func @with_replicate( // ----- -// Tests that the pass does not add control dependencies a stateless if op. +// Tests that the pass does not add control dependencies for a stateless if op. // CHECK-LABEL: func @stateless_if_op func @stateless_if_op( @@ -361,6 +361,83 @@ func @if_else(%arg0: tensor) -> tensor { // ----- +// Tests that the pass does not add control dependencies for a stateless +// IfRegion op. 
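A note on reading the annotations in the side-effect-analysis tests below: every op gets an {{ID: N}} remark, with IDs apparently assigned in textual order except that an op owning regions is numbered after everything nested inside it, and {{Successors: ...}} / {{Predecessors: ...}} list the IDs of ops it must be ordered before / after. A reduced fragment for orientation (assume %arg0 is a resource-typed function argument and the island sits inside a tf_executor.graph; IDs are local to this sketch):

  %island = tf_executor.island {
    %r = "tf.ReadVariableOp"(%arg0)
    // expected-remark@above {{ID: 0}}
    // expected-remark@above {{Successors: {1}}}
         : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
    "tf.AssignVariableOp"(%arg0, %r)
    // expected-remark@above {{ID: 1}}
    // expected-remark@above {{Predecessors: {0}}}
    // expected-remark@above {{Successors: {2}}}
         : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
    tf_executor.yield
    // expected-remark@above {{ID: 2}}
    // expected-remark@above {{Predecessors: {1}}}
  }
  // expected-remark@above {{ID: 3}}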
+ +// CHECK-LABEL: func @stateless_ifregion_op +func @stateless_ifregion_op( + // expected-remark@above {{ID: 18}} + %arg0: tensor<*x!tf.resource>>, + %arg1: tensor) { + tf_executor.graph { + // expected-remark@above {{ID: 16}} + // expected-remark@above {{Successors: {17}}} + // CHECK: tf_executor.island + %island = tf_executor.island { + // expected-remark@above {{ID: 14}} + // expected-remark@above {{Successors: {15}}} + + %r0 = "tf.ReadVariableOp"(%arg0) : + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {12}}} + (tensor<*x!tf.resource>>) -> tensor<32xf32> + + %if = "tf.IfRegion"(%arg1) ( + // expected-remark@above {{ID: 11}} + { // Then region. + %graph = tf_executor.graph { + // expected-remark@above {{ID: 4}} + %island:2 = tf_executor.island { + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Successors: {3}}} + tf_executor.yield %arg1 : tensor + // expected-remark@above {{ID: 1}} + } + tf_executor.fetch %island#0 : tensor + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Predecessors: {2}}} + } + "tf.Yield"(%graph) : (tensor) -> () + // expected-remark@above {{ID: 5}} + }, { // Else region + %graph = tf_executor.graph { + // expected-remark@above {{ID: 9}} + %island:2 = tf_executor.island { + // expected-remark@above {{ID: 7}} + // expected-remark@above {{Successors: {8}}} + tf_executor.yield %arg1 : tensor + // expected-remark@above {{ID: 6}} + } + tf_executor.fetch %island#0 : tensor + // expected-remark@above {{ID: 8}} + // expected-remark@above {{Predecessors: {7}}} + } + "tf.Yield"(%graph) : (tensor) -> () + // expected-remark@above {{ID: 10}} + } + ) { is_stateless = true} : (tensor) -> tensor + + "tf.AssignVariableOp"(%arg0, %r0) : + // expected-remark@above {{ID: 12}} + // expected-remark@above {{Predecessors: {0}}} + // expected-remark@above {{Successors: {13}}} + (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + + tf_executor.yield + // expected-remark@above {{ID: 13}} + // expected-remark@above {{Predecessors: {12}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 15}} + // expected-remark@above {{Predecessors: {14}}} + } + return + // expected-remark@above {{ID: 17}} + // expected-remark@above {{Predecessors: {16}}} +} + +// ----- + // Tests that the pass does not add control dependencies a stateless while op. // CHECK-LABEL: func @stateless_if_op @@ -379,7 +456,7 @@ func @stateless_if_op( // expected-remark@above {{ID: 0}} // expected-remark@above {{Successors: {2}}} (tensor<*x!tf.resource>>) -> tensor<32xf32> - %if = "tf.While"(%arg1) { + %while = "tf.While"(%arg1) { // expected-remark@above {{ID: 1}} body = @while_body, cond = @while_cond, is_stateless = true} : (tensor) -> tensor @@ -445,9 +522,98 @@ func @while_cond(%arg0: tensor) -> tensor { // ----- +// Tests that the pass does not add control dependencies a stateless WhileRegion +// op. 
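For contrast with the stateless test below: when is_stateless is false the analysis can no longer assume the region op is side-effect free, so it would be ordered conservatively between the surrounding resource ops (roughly read -> WhileRegion -> assign instead of read -> assign), which is what the @output_of_*_op tests later in this file exercise in detail. A minimal stateful counterpart, with the edges it would likely pick up noted in comments (a sketch, not a test in this change):

  %r0 = "tf.ReadVariableOp"(%arg0)            // would gain a successor edge to the WhileRegion
       : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
  %w = "tf.WhileRegion"(%arg1) ( {
    ^bb0(%carg: tensor<i1>):
      "tf.Yield"(%carg) : (tensor<i1>) -> ()
  }, {
    ^bb0(%barg: tensor<i1>):
      "tf.Yield"(%barg) : (tensor<i1>) -> ()
  }) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
  "tf.AssignVariableOp"(%arg0, %r0)           // would gain a predecessor edge from the WhileRegion
       : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()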
+ +// CHECK-LABEL: func @stateless_whileregion_op +func @stateless_whileregion_op( + // expected-remark@above {{ID: 18}} + %arg0: tensor<*x!tf.resource>>, + %arg1: tensor) { + tf_executor.graph { + // expected-remark@above {{ID: 16}} + // expected-remark@above {{Successors: {17}}} + // CHECK: tf_executor.island + %island = tf_executor.island { + // expected-remark@above {{ID: 14}} + // expected-remark@above {{Successors: {15}}} + %r0 = "tf.ReadVariableOp"(%arg0) : + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {12}}} + (tensor<*x!tf.resource>>) -> tensor<32xf32> + + %while = "tf.WhileRegion"(%arg1) ( + // expected-remark@above {{ID: 11}} + { + ^bb0(%carg: tensor): + %graph = tf_executor.graph { + // expected-remark@above {{ID: 4}} + %island:2 = tf_executor.island { + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Successors: {3}}} + tf_executor.yield %carg : tensor + // expected-remark@above {{ID: 1}} + } + tf_executor.fetch %island#0 : tensor + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Predecessors: {2}}} + } + "tf.Yield"(%graph) : (tensor) -> () + // expected-remark@above {{ID: 5}} + }, { + ^bb0(%barg: tensor): + %graph = tf_executor.graph { + // expected-remark@above {{ID: 9}} + %island:2 = tf_executor.island { + // expected-remark@above {{ID: 7}} + // expected-remark@above {{Successors: {8}}} + tf_executor.yield %barg : tensor + // expected-remark@above {{ID: 6}} + } + tf_executor.fetch %island#0 : tensor + // expected-remark@above {{ID: 8}} + // expected-remark@above {{Predecessors: {7}}} + } + "tf.Yield"(%graph) : (tensor) -> () + // expected-remark@above {{ID: 10}} + } + ) {is_stateless = true} : (tensor) -> tensor + "tf.AssignVariableOp"(%arg0, %r0) : + // expected-remark@above {{ID: 12}} + // expected-remark@above {{Predecessors: {0}}} + // expected-remark@above {{Successors: {13}}} + (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + tf_executor.yield + // expected-remark@above {{ID: 13}} + // expected-remark@above {{Predecessors: {12}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 15}} + // expected-remark@above {{Predecessors: {14}}} + } + return + // expected-remark@above {{ID: 17}} + // expected-remark@above {{Predecessors: {16}}} +} + +// ----- + // Tests that the pass tracks control dependencies for variables from an if op's // output. +// In this test, the resources computed and used are as follows: +// (* = unknown resource id which aliases with everything else) +// id0 = arg0 +// if-then-branch: [u0, arg0, arg0] +// if-else-branch: [arg0, arg0, arg1] +// => first result is unknown, second and third is passthrough +// if results : [*, arg0, {arg0, arg1}[ +// ID #2: read (unknown) -> succ {5, 6) +// ID #3: read (arg0) -> succ {5} +// ID #4: read({arg0,arg1}) -> succ {5,6} +// ID #5: write(arg0) +// ID #6: write(arg1) + // CHECK-LABEL: func @output_of_if_op func @output_of_if_op( // expected-remark@above {{ID: 12}} @@ -597,9 +763,151 @@ func @if_else( // ----- +// Tests that the pass tracks control dependencies for variables from an +// IfRegion op's output. 
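The same per-result bookkeeping as in the functional test above applies to the region form below. Spelled out for @output_of_ifregion_op, following the expected remarks in that test:

  then region yields (%u0,   %iid0, %iid0)  ->  per-result ids [unknown, arg0, arg0]
  else region yields (%iid0, %iid0, %iid1)  ->  per-result ids [arg0,    arg0, arg1]
  merged IfRegion results                   ->  [unknown, arg0, {arg0, arg1}]

So the reads of result #0 (ID 16) and result #2 (ID 18) must precede both writes (IDs 19 and 20), while the read of result #1 (ID 17) only orders against the write to arg0 (ID 19).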
+ +// CHECK-LABEL: func @output_of_ifregion_op +func @output_of_ifregion_op( + // expected-remark@above {{ID: 26}} + %arg0: tensor<*x!tf.resource>>, + %arg1: tensor<*x!tf.resource>>, + %arg2: tensor) { + tf_executor.graph { + // expected-remark@above {{ID: 24}} + // expected-remark@above {{Successors: {25}}} + // CHECK: tf_executor.island + %island = tf_executor.island { + // expected-remark@above {{ID: 22}} + // expected-remark@above {{Successors: {23}}} + %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) + // expected-remark@above {{ID: 0}} + -> tensor<*x!tf.resource>> + %if:3 = "tf.IfRegion"(%arg2) ( + // expected-remark@above {{ID: 15}} + // expected-remark@above {{Successors: {16,17,18}}} + { + %graph:3 = tf_executor.graph { + // expected-remark@above {{ID: 6}} + %island:4 = tf_executor.island { + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Successors: {5}}} + %u0 = "tf._UnknownSideEffectingOp_"() : () + // expected-remark@above {{ID: 1}} + // expected-remark@above {{Successors: {3}}} + -> tensor<*x!tf.resource>> + %iid0 = "tf.Identity"(%id0) : (tensor<*x!tf.resource>>) + // expected-remark@above {{ID: 2}} + -> tensor<*x!tf.resource>> + tf_executor.yield %u0, %iid0, %iid0 : + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Predecessors: {1}}} + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>> + } + tf_executor.fetch %island#0, %island#1, %island#2 : + // expected-remark@above {{ID: 5}} + // expected-remark@above {{Predecessors: {4}}} + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>> + } + "tf.Yield"(%graph#0, %graph#1, %graph#2) : + // expected-remark@above {{ID: 7}} + (tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) -> () + }, + { + %graph:3 = tf_executor.graph { + // expected-remark@above {{ID: 13}} + %island:4 = tf_executor.island { + // expected-remark@above {{ID: 11}} + // expected-remark@above {{Successors: {12}}} + %iid0 = "tf.Identity"(%id0) : (tensor<*x!tf.resource>>) + // expected-remark@above {{ID: 8}} + -> tensor<*x!tf.resource>> + %iid1 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>) + // expected-remark@above {{ID: 9}} + -> tensor<*x!tf.resource>> + tf_executor.yield %iid0, %iid0, %iid1 : + // expected-remark@above {{ID: 10}} + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>> + } + tf_executor.fetch %island#0, %island#1, %island#2 : + // expected-remark@above {{ID: 12}} + // expected-remark@above {{Predecessors: {11}}} + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>> + } + "tf.Yield"(%graph#0, %graph#1, %graph#2) : + // expected-remark@above {{ID: 14}} + (tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) -> () + }) { is_stateless = false} + : (tensor) -> + (tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) + %r0 = "tf.ReadVariableOp"(%if#0) : + // expected-remark@above {{ID: 16}} + // expected-remark@above {{Predecessors: {15}}} + // expected-remark@above {{Successors: {19,20}}} + (tensor<*x!tf.resource>>) -> tensor<32xf32> + %r1 = "tf.ReadVariableOp"(%if#1) : + // expected-remark@above {{ID: 17}} + // expected-remark@above {{Predecessors: {15}}} + // expected-remark@above {{Successors: {19}}} + (tensor<*x!tf.resource>>) -> tensor<32xf32> + %r2 = "tf.ReadVariableOp"(%if#2) : + // expected-remark@above {{ID: 18}} + // expected-remark@above {{Predecessors: {15}}} + // expected-remark@above {{Successors: {19,20}}} + 
(tensor<*x!tf.resource>>) -> tensor<32xf32> + "tf.AssignVariableOp"(%arg0, %r0) : + // expected-remark@above {{ID: 19}} + // expected-remark@above {{Predecessors: {16,17,18}}} + // expected-remark@above {{Successors: {21}}} + (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + "tf.AssignVariableOp"(%arg1, %r0) : + // expected-remark@above {{ID: 20}} + // expected-remark@above {{Predecessors: {16,18}}} + // expected-remark@above {{Successors: {21}}} + (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + tf_executor.yield + // expected-remark@above {{ID: 21}} + // expected-remark@above {{Predecessors: {19,20}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 23}} + // expected-remark@above {{Predecessors: {22}}} + } + return + // expected-remark@above {{ID: 25}} + // expected-remark@above {{Predecessors: {24}}} +} + +// ----- + // Tests that the pass tracks control dependencies for variables from a while // op's output. +// Here: +// id0 = arg0 +// while-inputs = (id0/arg0, arg1, arg1) +// while body pass through first and second arg, not last one +// while-results = (arg0, arg1, Unknown) +// #ID 2: read(arg0) -> succ{5} +// #ID 3: read(arg1) -> succ{6} +// #ID 4: read(unknown) -> succ{5,6} +// #ID 5 : write(arg0) +// #ID 6 : write(arg1) + + // CHECK-LABEL: func @output_of_while_op func @output_of_while_op( // expected-remark@above {{ID: 12}} @@ -631,24 +939,24 @@ func @output_of_while_op( // expected-remark@above {{Predecessors: {1}}} // expected-remark@above {{Successors: {5}}} (tensor<*x!tf.resource>>) -> tensor<32xf32> - %r1 = "tf.ReadVariableOp"(%while#1) : + %r1 = "tf.ReadVariableOp"(%while#2) : // expected-remark@above {{ID: 3}} // expected-remark@above {{Predecessors: {1}}} - // expected-remark@above {{Successors: {5}}} - (tensor<*x!tf.resource>>) -> tensor<32xf32> - %r2 = "tf.ReadVariableOp"(%while#2) : - // expected-remark@above {{ID: 4}} - // expected-remark@above {{Predecessors: {1}}} // expected-remark@above {{Successors: {6}}} (tensor<*x!tf.resource>>) -> tensor<32xf32> + %r2 = "tf.ReadVariableOp"(%while#3) : + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Predecessors: {1}}} + // expected-remark@above {{Successors: {5,6}}} + (tensor<*x!tf.resource>>) -> tensor<32xf32> "tf.AssignVariableOp"(%arg0, %r0) : // expected-remark@above {{ID: 5}} - // expected-remark@above {{Predecessors: {2,3}}} + // expected-remark@above {{Predecessors: {2,4}}} // expected-remark@above {{Successors: {7}}} (tensor<*x!tf.resource>>, tensor<32xf32>) -> () "tf.AssignVariableOp"(%arg1, %r0) : // expected-remark@above {{ID: 6}} - // expected-remark@above {{Predecessors: {4}}} + // expected-remark@above {{Predecessors: {3,4}}} // expected-remark@above {{Successors: {7}}} (tensor<*x!tf.resource>>, tensor<32xf32>) -> () tf_executor.yield @@ -740,6 +1048,136 @@ func @while_cond( // ----- +// Tests that the pass tracks control dependencies for variables from a +// WhileRegion op's output. 
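What "pass through" means above: a while result keeps the resource id of the body argument it forwards, and anything produced inside the body (such as an unknown side-effecting op) makes the corresponding result unknown. The actual @while_body is unchanged context and not shown in this diff; a body producing that aliasing would look roughly like the illustrative sketch below (not the real test body):

func @passthrough_body_sketch(
    %arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
    %arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
    %arg2: tensor<*x!tf.resource<tensor<32xf32>>>)
    -> (tensor<*x!tf.resource<tensor<32xf32>>>,
        tensor<*x!tf.resource<tensor<32xf32>>>,
        tensor<*x!tf.resource<tensor<32xf32>>>) {
  %u = "tf._UnknownSideEffectingOp_"() : () -> tensor<*x!tf.resource<tensor<32xf32>>>
  // Results #0 and #1 forward %arg0 and %arg1, so they keep those resource ids;
  // result #2 comes from an unknown op, so it may alias anything.
  return %arg0, %arg1, %u : tensor<*x!tf.resource<tensor<32xf32>>>,
                            tensor<*x!tf.resource<tensor<32xf32>>>,
                            tensor<*x!tf.resource<tensor<32xf32>>>
}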
+ +// CHECK-LABEL: func @output_of_whileregion_op +func @output_of_whileregion_op( + // expected-remark@above {{ID: 26}} + %arg0: tensor<*x!tf.resource>>, + %arg1: tensor<*x!tf.resource>>, + %arg2: tensor) { + tf_executor.graph { + // expected-remark@above {{ID: 24}} + // expected-remark@above {{Successors: {25}}} + // CHECK: tf_executor.island + %island = tf_executor.island { + // expected-remark@above {{ID: 22}} + // expected-remark@above {{Successors: {23}}} + %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) + // expected-remark@above {{ID: 0}} + -> tensor<*x!tf.resource>> + %while:4 = "tf.WhileRegion"(%arg2, %id0, %arg1, %arg1) ( + // expected-remark@above {{ID: 15}} + // expected-remark@above {{Successors: {16,17,18}}} + { + ^bb0(%pred: tensor, + %carg1: tensor<*x!tf.resource>>, + %carg2: tensor<*x!tf.resource>>, + %carg3: tensor<*x!tf.resource>>): + %graph = tf_executor.graph { + // expected-remark@above {{ID: 6}} + %island:2 = tf_executor.island { + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Successors: {5}}} + %const = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + // expected-remark@above {{ID: 1}} + %eq = "tf.Equal"(%pred, %const) : (tensor, tensor) -> tensor + // expected-remark@above {{ID: 2}} + tf_executor.yield %eq : tensor + // expected-remark@above {{ID: 3}} + } + tf_executor.fetch %island#0 : tensor + // expected-remark@above {{ID: 5}} + // expected-remark@above {{Predecessors: {4}}} + } + "tf.Yield"(%graph) : (tensor) -> () + // expected-remark@above {{ID: 7}} + }, + { + ^bb0(%pred: tensor, + %barg0: tensor<*x!tf.resource>>, + %barg1: tensor<*x!tf.resource>>, + %barg2: tensor<*x!tf.resource>>): + %graph:4 = tf_executor.graph { + // expected-remark@above {{ID: 13}} + %island:5 = tf_executor.island { + // expected-remark@above {{ID: 11}} + // expected-remark@above {{Successors: {12}}} + %iid0 = "tf.Identity"(%barg0) : (tensor<*x!tf.resource>>) + // expected-remark@above {{ID: 8}} + -> tensor<*x!tf.resource>> + %u0 = "tf._UnknownSideEffectingOp_"() : () + // expected-remark@above {{ID: 9}} + // expected-remark@above {{Successors: {10}}} + -> tensor<*x!tf.resource>> + tf_executor.yield %pred, %iid0, %barg1, %u0 : + // expected-remark@above {{ID: 10}} + // expected-remark@above {{Predecessors: {9}}} + tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>> + } + tf_executor.fetch %island#0, %island#1, %island#2, %island#3 : + // expected-remark@above {{ID: 12}} + // expected-remark@above {{Predecessors: {11}}} + tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>> + } + "tf.Yield"(%graph#0, %graph#1, %graph#2, %graph#3) : + // expected-remark@above {{ID: 14}} + (tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) -> () + } + ) {is_stateless = false} + : (tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) -> + (tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) + %r0 = "tf.ReadVariableOp"(%while#1) : + // expected-remark@above {{ID: 16}} + // expected-remark@above {{Predecessors: {15}}} + // expected-remark@above {{Successors: {19}}} + (tensor<*x!tf.resource>>) -> tensor<32xf32> + %r1 = "tf.ReadVariableOp"(%while#2) : + // expected-remark@above {{ID: 17}} + // expected-remark@above {{Predecessors: {15}}} + // expected-remark@above {{Successors: {20}}} + (tensor<*x!tf.resource>>) -> tensor<32xf32> + %r2 = "tf.ReadVariableOp"(%while#3) : + // expected-remark@above {{ID: 
18}} + // expected-remark@above {{Predecessors: {15}}} + // expected-remark@above {{Successors: {19,20}}} + (tensor<*x!tf.resource>>) -> tensor<32xf32> + "tf.AssignVariableOp"(%arg0, %r0) : + // expected-remark@above {{ID: 19}} + // expected-remark@above {{Predecessors: {16,18}}} + // expected-remark@above {{Successors: {21}}} + (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + "tf.AssignVariableOp"(%arg1, %r0) : + // expected-remark@above {{ID: 20}} + // expected-remark@above {{Predecessors: {17,18}}} + // expected-remark@above {{Successors: {21}}} + (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + tf_executor.yield + // expected-remark@above {{ID: 21}} + // expected-remark@above {{Predecessors: {19,20}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 23}} + // expected-remark@above {{Predecessors: {22}}} + } + return + // expected-remark@above {{ID: 25}} + // expected-remark@above {{Predecessors: {24}}} +} + +// ----- + // Tests that the pass tracks control dependencies based on TF op registry // statefulness flag, for ops not yet defined in ODS. @@ -824,4 +1262,3 @@ func @arguments_with_unique_ids( // expected-remark@above {{ID: 8}} // expected-remark@above {{Predecessors: {7}}} } - diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir index 3d187aa5d60..92cb0458bf9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir @@ -256,7 +256,7 @@ func @main(%arg0: tensor) -> () { %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor // CHECK-NOT: tf.EmptyTensorList %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> - %case_op = "tf.Case"(%arg0, %tl) {branches = [@branch_0, @branch_1, @branch_2]} + %case_op = "tf.Case"(%arg0, %tl) {branches = [@branch_0, @branch_1, @branch_2], is_stateless = false} : (tensor, tensor>>) -> tensor>> // CHECK: "tf.Slice" %pop:2 = "tf.TensorListPopBack"(%case_op, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 20a0e22c48e..1d5e6aad982 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -183,6 +183,20 @@ func @testLeakyWrongAlphaType(tensor<16xf32>) -> tensor<16xf32> { // ----- +// Test tf.Min with complex numbers. +// Previous versions of tensorflow said complex numbers were allowed with +// tf.Min even though it doesn't make sense. The legalization of tf to xla +// requires that complex types are not allowed in tf.Min, so we have an +// explicit unit here to make sure that invariant is enforced. 
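For contrast with the rejected complex case below, the same reduction on a floating point tensor is accepted. A sketch of the positive form, not part of the test file:

func @testMinFloat(%arg0: tensor<4x8xf32>) -> tensor<4x1xf32> {
  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
  // Reducing over dimension 1 with keep_dims = true keeps a size-1 axis in the result.
  %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true } : (tensor<4x8xf32>, tensor<1xi64>) -> tensor<4x1xf32>
  return %0 : tensor<4x1xf32>
}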
+func @testMinComplex(%arg0: tensor<4x8xcomplex>) -> tensor<4x1xcomplex> { + %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> + // expected-error@below {{'tf.Min' op operand #0 must be tensor of}} + %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xcomplex>, tensor<1xi64>) -> tensor<4x1xcomplex> + return %0 : tensor<4x1xcomplex> +} + +// ----- + // CHECK-LABEL: func @testMul func @testMul(%arg0: tensor<2xui16>) -> (tensor<2xui16>) { %0 = "tf.Mul"(%arg0, %arg0) {T = "tfdtype$DT_UINT16", device = "/device:CPU:0", name = "Mul"} : (tensor<2xui16>, tensor<2xui16>) -> tensor<2xui16> @@ -775,12 +789,30 @@ func @testInvalidIfOp(tensor, tensor<2xf32>) -> tensor<2xf32> { // ----- func @testIfThen(tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -func @testIfElse(tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +func @testIfElse(tensor<2xf32>) -> tensor<2xf32> // Test invalid tf.If operation func @testInvalidIfOp(tensor, tensor<2xf32>) -> tensor<2xf32> { ^bb0(%arg0: tensor, %arg1: tensor<2xf32>): - // expected-error @+1 {{branches should have 1 inputs}} + // expected-error @+1 {{'tf.If' op 'then_branch' inputs (size = 2) should have the same number of values as inputs (size = 1)}} + %1 = "tf.If"(%arg0, %arg1) { + then_branch = @testIfThen, + else_branch = @testIfElse, + is_stateless = false + } : (tensor, tensor<2xf32>) -> tensor<2xf32> + + return %1 : tensor<2xf32> +} + +// ----- + +func @testIfThen(tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) +func @testIfElse(tensor<2xf32>) -> tensor<2xf32> + +// Test invalid tf.If operation +func @testInvalidIfOp(tensor, tensor<2xf32>) -> tensor<2xf32> { +^bb0(%arg0: tensor, %arg1: tensor<2xf32>): + // expected-error @+1 {{'tf.If' op 'then_branch' results (size = 2) should have the same number of values as results (size = 1)}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, else_branch = @testIfElse, @@ -798,7 +830,7 @@ func @testIfElse(tensor<*xf32>) -> tensor<*xf32> // Test invalid tf.If operation func @testInvalidIfOp(tensor, tensor<2xf32>) -> tensor<2xf32> { ^bb0(%arg0: tensor, %arg1: tensor<2xf32>): - // expected-error @+1 {{then branch input type tensor<*xf16> is incompatible with operand type tensor<2xf32>}} + // expected-error @+1 {{'tf.If' op 'then_branch' input type tensor<*xf16> is incompatible with input type tensor<2xf32> at index 0}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, else_branch = @testIfElse, @@ -816,7 +848,7 @@ func @testIfElse(tensor<3xf32>) -> tensor<*xf32> // Test invalid tf.If operation func @testInvalidIfOp(tensor, tensor<*xf32>) -> tensor<2xf32> { ^bb0(%arg0: tensor, %arg1: tensor<*xf32>): - // expected-error @+1 {{branches inputs have incompatible types tensor<2xf32> and tensor<3xf32>}} + // expected-error @+1 {{expects all branch input type(s) (tensor<2xf32>, tensor<3xf32>) at index 0 to be cast compatible}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, else_branch = @testIfElse, @@ -834,7 +866,7 @@ func @testIfElse(tensor<*xf32>) -> tensor<3xf32> // Test invalid tf.If operation func @testInvalidIfOp(tensor, tensor<*xf32>) -> tensor<2xf32> { ^bb0(%arg0: tensor, %arg1: tensor<*xf32>): - // expected-error @+1 {{else branch result type tensor<3xf32> is incompatible with op result type tensor<2xf32>}} + // expected-error @+1 {{'tf.If' op 'else_branch' result type tensor<3xf32> is incompatible with result type tensor<2xf32> at index 0}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, else_branch = @testIfElse, @@ -848,7 +880,7 @@ func 
@testInvalidIfOp(tensor, tensor<*xf32>) -> tensor<2xf32> { // Test invalid tf.Yield operation (parent should be IfRegion) func @testInvalidYieldOp(%arg0: f32) -> () { - // expected-error @+1 {{'tf.Yield' op expects parent op to be one of 'tf.IfRegion, tf.WhileRegion'}} + // expected-error @+1 {{'tf.Yield' op expects parent op to be one of 'tf.CaseRegion, tf.IfRegion, tf.WhileRegion'}} "tf.Yield"(%arg0) : (f32) -> () } @@ -895,7 +927,7 @@ func @testValidIfRegionOpWithMultipleResults(%arg0: tensor, %arg1: tensor<2x // Test invalid type for operand #0 for tf.IfRegion operation func @testInvalidIfRegionOpType0(%arg0: f32, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{operand #0 must be tensor of tf.dtype values}} + // expected-error @+1 {{operand #0 must be 0D tensor of 1-bit signless integer values, but got 'f32'}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () @@ -982,7 +1014,7 @@ func @testIfRegionElseTerminator(%arg0: tensor, %arg1: tensor<2xf32>) -> ten // tf.Region yield number of results should match op number of results func @testIfRegionThenResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{'tf.IfRegion' op then should have same number (1) of results as tf.IfRegion but has 2 results}} + // expected-error @+1 {{'tf.IfRegion' op then results (size = 2) should have the same number of values as results (size = 1)}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t, %t) : (tensor<2xf32>, tensor<2xf32>) -> () @@ -997,7 +1029,7 @@ func @testIfRegionThenResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> te // ----- func @testIfRegionElseResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{tf.IfRegion' op else should have same number (1) of results as tf.IfRegion but has 2 results}} + // expected-error @+1 {{'tf.IfRegion' op else results (size = 2) should have the same number of values as results (size = 1)}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () @@ -1013,7 +1045,7 @@ func @testIfRegionElseResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> te // tf.IfRegion yield types should match op result types func @testIfRegionOpYieldMismatchThen(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{then result type tensor is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}} + // expected-error @+1 {{'tf.IfRegion' op then result type tensor is incompatible with result type tensor<2xf32> at index 0}} %0 = "tf.IfRegion"(%arg0) ({ "tf.Yield"(%arg0) : (tensor) -> () }, { @@ -1027,7 +1059,7 @@ func @testIfRegionOpYieldMismatchThen(%arg0: tensor, %arg1: tensor<2xf32>) - // ----- func @testIfRegionOpYieldMismatchElse(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{else result type tensor is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}} + // expected-error @+1 {{'tf.IfRegion' op else result type tensor is incompatible with result type tensor<2xf32> at index 0}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () @@ -1509,7 +1541,7 @@ func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>) // Test invalid 'While' operation func @testWhileResult(tensor<*xf32>) -> (tensor<*xi32>) { ^bb0(%arg0: tensor<*xf32>): - // expected-error 
@+1 {{operand type tensor<*xf32> is incompatible with result type}} + // expected-error @+1 {{'tf.While' op input type tensor<*xf32> is incompatible with result type tensor<*xi32> at index 0}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, body = @testWhileBody, @@ -1527,7 +1559,7 @@ func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>) // Test invalid 'While' operation func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{operand type tensor<*xf32> is incompatible with cond function input type}} + // expected-error @+1 {{'tf.While' op input type tensor<*xf32> is incompatible with condition input type tensor<*xi32> at index 0}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, body = @testWhileBody, @@ -1545,7 +1577,7 @@ func @testWhileBody(tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>) // Test invalid 'While' operation func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{requires the number of operands to be equal to the number of body function inputs. Found 1 and 2, respectively}} + // expected-error @+1 {{'tf.While' op inputs (size = 1) should have the same number of values as body inputs (size = 2)}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, body = @testWhileBody, @@ -1563,7 +1595,7 @@ func @testWhileBody(tensor<*xf32>) -> (tensor<*xi32>) // Test invalid 'While' operation func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{body function result type tensor<*xi32> is incompatible with result type}} + // expected-error @+1 {{'tf.While' op body result type tensor<*xi32> is incompatible with result type tensor<*xf32> at index 0}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, body = @testWhileBody, @@ -1581,7 +1613,7 @@ func @testWhileBody(tensor<4xf32>) -> (tensor<*xf32>) // Test invalid 'While' operation func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{cond function input type tensor<3xf32> is incompatible with body function input type}} + // expected-error @+1 {{'tf.While' op condition input type tensor<3xf32> is incompatible with body input type tensor<4xf32> at index 0}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, body = @testWhileBody, @@ -1600,7 +1632,7 @@ func @testWhileBody(tensor<*x!tf.resource>>) -> (tensor>>) -> (tensor>>) { ^bb0(%arg0: tensor<*x!tf.resource>>): - // expected-error @+1 {{operand type tensor<*x!tf.resource>> is incompatible with result type}} + // expected-error @+1 {{'tf.While' op input type tensor<*x!tf.resource>> is incompatible with result type tensor>> at index 0}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, body = @testWhileBody, @@ -1696,48 +1728,71 @@ func @testValidWhileRegionNoInputs() -> () { } // ----- +// Invalid while tests. There are 5 sets of type matching that is required +// I = input, O = output, BI, BO = body input/output, CI = cond input. +// [I, O], [I, CI], [I, BI], [BO, BI], [BO, O]. +// Each check can fail due to number or type mismatch. However, these +// conditions are not all independent. So we just check I->{CI, BI}, O->BO, and +// in addition I->O. BO->BI mismatch cannot be independently created without +// breaking one of these mismatches. That gives us 4x2 tests. In addition +// condition result needs to be tensor, for which we have 3 +// additional validation tests. 
All these tests are based on the following +// valid while -func @testInvalidWhileRegionMismatchCondInputCount(%arg : tensor) -> (tensor) { - // expected-error @+1 {{'tf.WhileRegion' op condition should have same number of inputs (1) as tf.WhileRegion but has 0 inputs}} - %0 = "tf.WhileRegion"(%arg) ( - { - // ^bb0(%carg: tensor): - %true = constant dense<1> : tensor - "tf.Yield"(%true) : (tensor) -> () - }, - { - ^bb0(%barg: tensor): - "tf.Yield"(%arg) : (tensor) -> () - } - ) : (tensor) -> (tensor) +func @testInvalidTestValidBase(%arg0 : tensor) -> (tensor) { + %0 = "tf.WhileRegion"(%arg0) ( + { + ^bb0(%carg: tensor): + %false = constant dense : tensor + "tf.Yield"(%false) : (tensor) -> () + }, + { + ^bb0(%barg: tensor): + "tf.Yield"(%barg) : (tensor) -> () + } + ) { is_stateless = true } : (tensor) -> (tensor) + return %0 : tensor +} +func @testInvalidWhileRegion_I_CI_CountMismatch(%arg0 : tensor) -> (tensor) { + // expected-error @+1 {{'tf.WhileRegion' op inputs (size = 1) should have the same number of values as condition inputs (size = 0)}} + %0 = "tf.WhileRegion"(%arg0) ( + { + //^bb0(%carg: tensor): + %false = constant dense : tensor + "tf.Yield"(%false) : (tensor) -> () + }, + { + ^bb0(%barg: tensor): + "tf.Yield"(%barg) : (tensor) -> () + } + ) { is_stateless = true } : (tensor) -> (tensor) return %0 : tensor } // ----- -func @testInvalidWhileRegionMismatchCondInputType(%arg : tensor) -> (tensor) { - // expected-error @+1 {{'tf.WhileRegion' op condition input type tensor is incompatible with tf.WhileRegion input type tensor at index 0}} - %0 = "tf.WhileRegion"(%arg) ( - { - ^bb0(%carg: tensor): - %true = constant dense<1> : tensor - "tf.Yield"(%true) : (tensor) -> () - }, - { - ^bb0(%barg: tensor): - "tf.Yield"(%barg) : (tensor) -> () - } - ) : (tensor) -> (tensor) - +func @testInvalidWhileRegion_I_CI_TypeMismatch(%arg0 : tensor) -> (tensor) { + // expected-error @+1 {{'tf.WhileRegion' op input type tensor is incompatible with condition input type tensor at index 0}} + %0 = "tf.WhileRegion"(%arg0) ( + { + ^bb0(%carg: tensor): + %false = constant dense : tensor + "tf.Yield"(%false) : (tensor) -> () + }, + { + ^bb0(%barg: tensor): + "tf.Yield"(%barg) : (tensor) -> () + } + ) { is_stateless = true } : (tensor) -> (tensor) return %0 : tensor } // ----- -func @testInvalidWhileRegionMismatchBodyInputCount(%arg : tensor) -> (tensor) { - // expected-error @+1 {{'tf.WhileRegion' op body should have same number of inputs (1) as tf.WhileRegion but has 2 inputs}} - %0 = "tf.WhileRegion"(%arg) ( +func @testInvalidWhileRegion_I_BI_CountMismatch(%arg0 : tensor) -> (tensor) { + // expected-error @+1 {{'tf.WhileRegion' op inputs (size = 1) should have the same number of values as body inputs (size = 2)}} + %0 = "tf.WhileRegion"(%arg0) ( { ^bb0(%carg: tensor): %true = constant dense<1> : tensor @@ -1754,9 +1809,9 @@ func @testInvalidWhileRegionMismatchBodyInputCount(%arg : tensor) -> (tenso // ----- -func @testInvalidWhileRegionMismatchBodyInputType(%arg : tensor) -> (tensor) { - // expected-error @+1 {{body input type tensor is incompatible with tf.WhileRegion input type tensor at index 0}} - %0 = "tf.WhileRegion"(%arg) ( +func @testInvalidWhileRegion_I_BI_TypeMismatch(%arg0 : tensor) -> (tensor) { + // expected-error @+1 {{'tf.WhileRegion' op input type tensor is incompatible with body input type tensor at index 0}} + %0 = "tf.WhileRegion"(%arg0) ( { ^bb0(%carg: tensor): %true = constant dense<1> : tensor @@ -1774,6 +1829,77 @@ func @testInvalidWhileRegionMismatchBodyInputType(%arg : tensor) -> 
(tensor // ----- +func @testInvalidWhileRegion_O_BO_CountMismatch(%arg0 : tensor) -> (tensor) { + // expected-error @+1 {{'tf.WhileRegion' op body results (size = 2) should have the same number of values as results (size = 1)}} + %0 = "tf.WhileRegion"(%arg0) ( + { + ^bb0(%carg: tensor): + %false = constant dense : tensor + "tf.Yield"(%false) : (tensor) -> () + }, + { + ^bb0(%barg: tensor): + "tf.Yield"(%barg, %barg) : (tensor, tensor) -> () + } + ) { is_stateless = true } : (tensor) -> (tensor) + return %0#0 : tensor +} + +// ----- + +func @testInvalidWhileRegionMismatch_O_BO_TypeMismatch(%arg0 : tensor, %arg1: tensor) -> (tensor) { + // expected-error @+1 {{'tf.WhileRegion' op body result type tensor is incompatible with result type tensor at index 0}} + %0 = "tf.WhileRegion"(%arg0) ( + { + ^bb0(%carg: tensor): + %false = constant dense : tensor + "tf.Yield"(%false) : (tensor) -> () + }, + { + ^bb0(%barg: tensor): + "tf.Yield"(%arg1) : (tensor) -> () + } + ) { is_stateless = true } : (tensor) -> (tensor) + return %0 : tensor +} + +// ----- + +func @testInvalidWhileRegion_I_O_CountMismatch(%arg0 : tensor) -> (tensor) { + // expected-error@+1 {{'tf.WhileRegion' op inputs (size = 1) should have the same number of values as results (size = 2)}} + %0:2 = "tf.WhileRegion"(%arg0) ( + { + ^bb0(%carg: tensor): + %false = constant dense : tensor + "tf.Yield"(%false) : (tensor) -> () + }, + { + ^bb0(%barg: tensor): + "tf.Yield"(%barg, %barg) : (tensor, tensor) -> () + } + ) { is_stateless = true } : (tensor) -> (tensor, tensor) + return %0#0 : tensor +} + +// ----- + +func @testInvalidWhileRegion_I_O_TypeMismatch(%arg0: tensor, %arg1 : tensor) -> (tensor) { + // expected-error@+1 {{'tf.WhileRegion' op input type tensor is incompatible with result type tensor at index 0}} + %0 = "tf.WhileRegion"(%arg0) ( + { + ^bb0(%carg: tensor): + %false = constant dense : tensor + "tf.Yield"(%false) : (tensor) -> () + }, + { + ^bb0(%barg: tensor): + "tf.Yield"(%arg1) : (tensor) -> () + } + ) { is_stateless = true } : (tensor) -> (tensor) + return %0 : tensor +} +// ----- + func @testInvalidWhileRegionConditionOutputCount2(%arg : tensor) -> (tensor) { // expected-error @+1 {{'tf.WhileRegion' op condition should have a single tensor result}} %0 = "tf.WhileRegion"(%arg) ( @@ -1827,45 +1953,6 @@ func @testInvalidWhileRegionConditionOutputType(%arg : tensor) -> (tensor } -// ----- - -func @testInvalidWhileRegionMismatchBodyOutputCount(%arg : tensor) -> (tensor) { - // expected-error @+1 {{'tf.WhileRegion' op body should have same number (1) of results as tf.WhileRegion but has 2 results}} - %0 = "tf.WhileRegion"(%arg) ( - { - ^bb0(%carg: tensor): - %true = constant dense<1> : tensor - "tf.Yield"(%true) : (tensor) -> () - }, - { - ^bb0(%barg: tensor): - %false = constant dense<1> : tensor - "tf.Yield"(%barg, %false) : (tensor, tensor) -> () - } - ) : (tensor) -> (tensor) - - return %0 : tensor -} - -// ----- - -func @testInvalidWhileRegionMismatchBodyOutputType(%arg : tensor) -> (tensor) { - // expected-error @+1 {{body result type tensor is incompatible with tf.WhileRegion result type tensor at index 0}} - %0 = "tf.WhileRegion"(%arg) ( - { - ^bb0(%carg: tensor): - %true = constant dense<1> : tensor - "tf.Yield"(%true) : (tensor) -> () - }, - { - ^bb0(%barg: tensor): - %c = "tf.Cast"(%barg) : (tensor) -> tensor - "tf.Yield"(%c) : (tensor) -> () - } - ) : (tensor) -> (tensor) - - return %0 : tensor -} // ----- @@ -2033,6 +2120,15 @@ func @testConst() -> tensor { // ----- +// Test invalid tf.ToBool +func 
@testInvalidToBool(%arg0: tensor) -> tensor<1xi1> { + // expected-error @+1 {{op result #0 must be 0D tensor of 1-bit signless integer values, but got 'tensor<1xi1>'}} + %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor<1xi1> + return %0 : tensor<1xi1> +} + +// ----- + // Test valid tf.Transpose // CHECK-LABEL: testTranspose func @testTranspose(tensor<2x3xf32>) -> tensor<3x2xf32> { @@ -2354,6 +2450,25 @@ func @testSlice_unknown_begin_in_bounds(%arg0: tensor<4xi32>, %begins: tensor<1x // ----- +func @testSlice_unequal_output_input_rank(%arg0: tensor<4xi32>, %begins: tensor<1xi64>) -> tensor { + %sizes = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) + // expected-error @+1 {{requires output to have the same rank as input, but got input rank 1 and output rank 0}} + %0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor + return %0 : tensor +} + +// ----- + +func @testSlice_wrong_output_size(%arg0: tensor<4xi32>) -> tensor<1xi32> { + %begins = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) + %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) + // expected-error @+1 {{requires output size to have the same size of slice, got slice size 2 and output size 1}} + %0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi32> + return %0 : tensor<1xi32> +} + +// ----- + // Valid StridedSlice operation. func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>) -> tensor { %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor @@ -3138,6 +3253,125 @@ func @testBatchMatMulV2(%lhs: tensor<10x10xf32>, %rhs: tensor) { // ----- +// CHECK-LABEL: func @testBatchMatMulV2NoBatchDimension +func @testBatchMatMulV2NoBatchDimension(%lhs: tensor<5x10xf32>, %rhs: tensor<10x10xf32>) -> (tensor<5x10xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<5x10xf32>, tensor<10x10xf32>) -> tensor<5x10xf32> + return %0 : tensor<5x10xf32> +} + +// ----- + +// CHECK-LABEL: func @testBatchMatMulV2ValidBroadcastingBatchDimension +func @testBatchMatMulV2ValidBroadcastingBatchDimension(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) -> (tensor<10x2x5x10xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x5x10xf32> + return %0 : tensor<10x2x5x10xf32> +} + +// ----- + +// CHECK-LABEL: func @testBatchMatMulV2ValidMultiBatchDimension +func @testBatchMatMulV2ValidMultiBatchDimension(%lhs: tensor<4x5x1x3x2xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x2x5xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x1x3x2xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x2x5xf32> + return %0 : tensor<4x5x1x2x5xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidBroadcastingBatchDimensionWithHigherXRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10x10xf32>) { + // expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<10x2x5x10xf32>' and rhs shape 'tensor<10x10x10xf32>'}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10x10xf32>) -> tensor<10x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidBroadcastingBatchDimensionWithSameRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10x10x10xf32>) { + // expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<10x2x5x10xf32>' 
and rhs shape 'tensor<10x10x10x10xf32>'}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidBroadcastingBatchDimensionWithHigherYRank(%lhs: tensor<2x5x10xf32>, %rhs: tensor<10x10x10x10xf32>) { + // expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<2x5x10xf32>' and rhs shape 'tensor<10x10x10x10xf32>'}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<2x5x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidOutputBatchDimension(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<2x10x10xf32>) { + // expected-error @+1 {{has mismatching input batch dimension 2 and output batch dimension 3}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<2x10x10xf32>) -> tensor<10x3x10x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidOutputRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x1x10x10xf32>) { + // expected-error @+1 {{found invalid output rank, expected 4 but got 3}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x1x10x10xf32>) -> tensor<10x5x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidOutputRowDim(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) { + // expected-error @+1 {{found invalid output dimension on row, expected 5 but got 10}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x10x10xf32> +} + +// ----- + +func @testBatchMatMulV2AdjXInvalidOutputRowDim(%lhs: tensor<10x2x10x5xf32>, %rhs: tensor<10x10xf32>) { + // expected-error @+1 {{found invalid output dimension on row, expected 5 but got 10}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<10x2x10x5xf32>, tensor<10x10xf32>) -> tensor<10x2x10x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidOutputColDim(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) { + // expected-error @+1 {{found invalid output dimension on col, expected 10 but got 5}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x5x5xf32> +} + +// ----- + +func @testBatchMatMulV2AdjYInvalidOutputColDim(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<4x10xf32>) { + // expected-error @+1 {{found invalid output dimension on col, expected 4 but got 10}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_y = true } : (tensor<10x2x5x10xf32>, tensor<4x10xf32>) -> tensor<10x2x5x10xf32> +} + +// ----- + +// CHECK-LABEL: func @testBatchMatMulV2PartiallyKnownInputBatchDim +func @testBatchMatMulV2PartiallyKnownInputBatchDim(%lhs: tensor<4x5x?x3x2xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x?x2x5xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x?x3x2xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x?x2x5xf32> + return %0 : tensor<4x5x?x2x5xf32> +} + +// ----- + +// CHECK-LABEL: func @testBatchMatMulV2PartiallyKnownMatmulDim +func @testBatchMatMulV2PartiallyKnownMatmulDim(%lhs: tensor<4x5x1x?x3xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x5xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<4x5x1x?x3xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x5xf32> + return %0 : tensor<4x5x1x?x5xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidPartiallyKnownMatmulDim(%lhs: tensor<4x5x1x?x3xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x3xf32>) { + // expected-error @+1 {{found invalid output dimension on col, expected 5 but got 3}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<4x5x1x?x3xf32>, 
tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x3xf32> + return %0 : tensor<4x5x1x?x3xf32> +} + +// ----- + +func @testBatchMatMulV2AdjXInvalidPartiallyKnownMatmulDim(%lhs: tensor<4x5x1x3x?xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x3xf32>) { + // expected-error @+1 {{found invalid output dimension on col, expected 5 but got 3}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x1x3x?xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x3xf32> + return %0 : tensor<4x5x1x?x3xf32> +} + +// ----- + func @testDataFormatVecPermuteInvalid1dInput(%x: tensor<5xi32>) { // expected-error @+1 {{requires 1D input of size 4}} %0 = "tf.DataFormatVecPermute"(%x): (tensor<5xi32>) -> tensor<5xi32> @@ -3313,3 +3547,220 @@ func @testBatchToSpaceInvalidOutputDepth(%arg0: tensor<16x8x8x3xf32>, %arg1: ten %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<*xi32>) -> tensor<4x8x8x8xf32> return } + +// ----- + +func @branch() + +func @testCaseBadBranchIndicesShape(%arg0: tensor<8xi32>) { + // expected-error @+1 {{expects 'branch_index' to be a scalar, but got 'tensor<8xi32>'}} + "tf.Case"(%arg0) {branches = [@branch], is_stateless = false} : (tensor<8xi32>) -> () + return +} + +// ----- + +func @branch0(tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +func @branch1(tensor<2xf32>) -> tensor<2xf32> + +func @testCaseMismatchedNumOperands(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{'tf.Case' op branch #0 inputs (size = 2) should have the same number of values as inputs (size = 1)}} + %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +func @branch0(tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) +func @branch1(tensor<2xf32>) -> tensor<2xf32> + +func @testCaseMismatchedNumResults(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{'tf.Case' op branch #0 results (size = 2) should have the same number of values as results (size = 1)}} + %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +func @branch0(tensor<*xf16>) -> tensor<*xf32> +func @branch1(tensor<*xf32>) -> tensor<*xf32> + +func @testCaseOperandNotCastCompatible(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{'tf.Case' op branch #0 input type tensor<*xf16> is incompatible with input type tensor<2xf32> at index 0}} + %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +func @branch0(tensor<2xf32>) -> tensor<*xf32> +func @branch1(tensor<3xf32>) -> tensor<*xf32> + +func @testCaseBranchArgumentsNotCastCompatible(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<2xf32> { + // expected-error @+1 {{expects all branch input type(s) (tensor<2xf32>, tensor<3xf32>) at index 0 to be cast compatible}} + %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<*xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +func @branch0(tensor<*xf32>) -> tensor<*xf32> +func @branch1(tensor<*xf32>) -> tensor<3xf32> + +func @testCaseResultNotCastCompatible(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<2xf32> { + // expected-error @+1 {{'tf.Case' op branch #1 result type tensor<3xf32> is incompatible with result type 
tensor<2xf32> at index 0}} + %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<*xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +func @testCaseRegionNoRegions(%arg0: tensor) { + // expected-error @+1 {{expects to have at least 1 region}} + "tf.CaseRegion"(%arg0) {is_stateless = false} : (tensor) -> () + return +} + +// ----- + +func @testCaseRegionBadBranchIndicesShape(%arg0: tensor<8xi32>) { + // expected-error @+1 {{expects 'branch_index' to be a scalar, but got 'tensor<8xi32>'}} + "tf.CaseRegion"(%arg0) ( { + "tf.Yield"() : () -> () + }) {is_stateless = false} : (tensor<8xi32>) -> () + return +} + +// ----- + +func @testCaseRegionMismatchedNumResults(%arg0: tensor) { + // expected-error @+1 {{'tf.CaseRegion' op branch #0 results (size = 0) should have the same number of values as results (size = 1)}} + %1 = "tf.CaseRegion"(%arg0) ( { + "tf.Yield"() : () -> () + }) {is_stateless = false} : (tensor) -> tensor + return +} + +// ----- + +func @testCaseRegionMismatchedResultTypes(%arg0: tensor, %arg1: tensor) { + // expected-error @+1 {{'tf.CaseRegion' op branch #0 result type tensor is incompatible with result type tensor at index 0}} + %1 = "tf.CaseRegion"(%arg0) ( { + "tf.Yield"(%arg1) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + return +} + +// ----- + +// Test valid tf.Cumsum +func @testCumsum(%arg: tensor<8x16xf32>, %axis: tensor) -> tensor<8x16xf32> { + %0 = "tf.Cumsum"(%arg, %axis) : (tensor<8x16xf32>, tensor) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> +} + +// ----- + +func @testCumprod(%arg: tensor<8x16xf32>, %axis: tensor<2xi32>) -> tensor<8x16xf32> { + // expected-error @+1 {{requires scalar axis operand}} + %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<2xi32>) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> +} + +// ----- + +func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> { + %axis = constant dense<-3> : tensor + // expected-error @+1 {{axis operand should be within range [-2, 2)}} + %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> +} + +// ----- + +func @testTile(%arg0: tensor<2x3x?xf32>) { + %cst = constant dense <[2, 3, 4]> : tensor<3xi32> + %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3x?xf32>, tensor<3xi32>) -> tensor<4x9x?xf32> + return +} + +// ----- + +func @testTileMultipleNotRank1(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1xi32>) { + // expected-error @+1 {{expected multiples to be rank 1, got rank = 2}} + %0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<1x1xi32>) -> tensor<2x3xf32> + return +} + +// ----- + +func @testTileInputRankNotEqualToMultiplesSize(%arg0: tensor<2x3xf32>, %arg1: tensor<3xi32>) { + // expected-error @+1 {{expected size of multiples equal to rank of input, got multiples of size 3, and input of rank 2}} + %0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<3xi32>) -> tensor<2x3xf32> + return +} + +// ----- + +func @testTileInputRankNotEqualToOutputRank(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) { + // expected-error @+1 {{expected rank of input to equal to rank of output, got input of rank 2, and output of rank 3}} + %0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3x1xf32> + return +} + +// ----- + +func @testTileNegativeMultiples(%arg0: tensor<2x3xf32>) { + %cst = constant dense <[-1, 1]> : tensor<2xi32> + // expected-error @+1 {{expected multiples to be non-negative, got multiples[0] = -1}} + %0 = "tf.Tile"(%arg0, 
%cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32> + return +} + +// ----- + +func @testTileInvalidOutputShape(%arg0: tensor<2x3xf32>) { + %cst = constant dense <[2, 3]> : tensor<2xi32> + // expected-error @+1 {{requires input.shape[1] (3) * 3 to be equal to output.shape[1] (6)}} + %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<4x6xf32> + return +} + +// ----- + +// Test reference variable support for some ops (no errors expected) + +// CHECK-LABEL: @testMaximumWithRef +func @testMaximumWithRef(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.Maximum + %0 = "tf.Maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @testAddV2WithRef +func @testAddV2WithRef(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.AddV2 + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @testRealDivWithRef +func @testRealDivWithRef(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.RealDivOp + %0 = "tf.RealDivOp"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @testDivNoNanWithRef +func @testDivNoNanWithRef(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.DivNoNanOp + %0 = "tf.DivNoNanOp"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @testAddWithRef +func @testAddWithRef(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.Add + %0 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_index_selector.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_index_selector.mlir index 7fc2b210f91..11ceac1fe99 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_index_selector.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_index_selector.mlir @@ -9,17 +9,17 @@ func @select(%arg0: tensor, %arg1: tensor) -> (tensor, tensor tensor %1 = "tf.DeviceIndex"() {device = "", device_names = ["CPU", "GPU"]} : () -> tensor - %4 = "tf.Case"(%1, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>]} : (tensor, tensor, tensor) -> tensor + %4 = "tf.Case"(%1, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], is_stateless = false} : (tensor, tensor, tensor) -> tensor return %0, %4 : tensor, tensor } -func @add(%i: tensor, %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { +func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { %0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -func @sub(%i: tensor, %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { +func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { %0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir index 1e537880620..23a8e904ad9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir @@ -433,7 +433,7 @@ func @nextiteration(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { %1:3 = tf_executor.NextIteration.Source : tensor<*xf32> tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32> // CHECK: tf_executor.NextIteration.Source : tensor<*xf32> -// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : 
tensor<*xf32> +// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32> tf_executor.fetch %1#0 : tensor<*xf32> } return %0 : tensor<*xf32> @@ -445,7 +445,7 @@ func @nextiteration_with_attributes(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<* %1:3 = tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"} tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"} // CHECK: tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"} -// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"} +// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"} tf_executor.fetch %1#0 : tensor<*xf32> } return %0 : tensor<*xf32> @@ -457,9 +457,9 @@ func @nextiteration_control(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<* %1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32> %2:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : tensor<*xf32> %3:3 = tf_executor.NextIteration.Source : tensor<*xf32> - tf_executor.NextIteration.Sink [%3#1] %3#0, %1#2 : tensor<*xf32> + tf_executor.NextIteration.Sink[%3#1] %3#0, %1#2 : tensor<*xf32> // CHECK: tf_executor.NextIteration.Source : tensor<*xf32> -// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32> +// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32> tf_executor.fetch %3#0 : tensor<*xf32> } return %0 : tensor<*xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-cluster-cleanup-attributes.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-cluster-cleanup-attributes.mlir new file mode 100644 index 00000000000..6399d7d6fb0 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-cluster-cleanup-attributes.mlir @@ -0,0 +1,24 @@ +// RUN: tf-opt %s -tf-tpu-cleanup-cluster-attributes | FileCheck %s + +func @test(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "tf_device.cluster" + // CHECK-NOT: _tpu_replicate = + // CHECK-NOT: device = + %1 = "tf_device.cluster"() ( { + %2 = "tf.Add"(%arg1, %arg1) : (tensor, tensor) -> tensor + %3 = "tf.IfRegion"(%arg0) ({ + %4 = "tf.Mul" (%arg1, %2) {device = "y"}: (tensor, tensor) -> tensor + "tf.Yield"(%4) : (tensor) -> () + }, { + %5 = "tf.Div" (%arg1, %2) : (tensor, tensor) -> tensor + "tf.Yield"(%5) : (tensor) -> () + }) {is_stateless = true, _tpu_replicate = "x" } : (tensor) -> (tensor) + tf_device.return %3 : tensor + // CHECK: {_tpu_replicate = "x", cluster_attr = "cluster_attr", device = "y"} + }) {cluster_attr = "cluster_attr", _tpu_replicate = "x", device = "y"} : () -> tensor + // CHECK: "tf.Add" + // CHECK-SAME: {_tpu_replicate = "x", device = "y"} + %2 = "tf.Add"(%arg2, %1) {_tpu_replicate = "x", device = "y"} : (tensor, tensor) -> tensor + // CHECK: return + return %2 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir index 9467f890419..7b670cd831c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir @@ -11,9 +11,9 @@ func @non_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"} NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two 
return values. metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false} // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false} // CHECK: %[[ITER:.*]]:2 = "tf.IteratorGetNext" @@ -31,7 +31,7 @@ func @non_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"} // CHECK-NEXT: "tf.TPUExecute"(%[[COPY0]], %[[COPY1]], %[[COMPILE]]#1) %execute = "tf_device.launch"() ( { %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %3 : tensor }) {device = "/device:TPU:0"} : () -> tensor return %execute : tensor @@ -49,9 +49,9 @@ func @multiple_compile_uses(%arg0: tensor<*x!tf.resource> {tf.device = "/device: NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) // CHECK-NOT: "tf.TPUGetLayoutOp" // CHECK-NOT: "tf.TPUCopyWithLayout" %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"} @@ -62,13 +62,13 @@ func @multiple_compile_uses(%arg0: tensor<*x!tf.resource> {tf.device = "/device: }) {device = "/device:CPU:0"} : () -> () %execute0 = "tf_device.launch"() ( { %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %3 : tensor }) {device = "/device:TPU:0"} : () -> tensor %4:2 = "tf._UnKnownOp_"() : () -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>) %execute1 = "tf_device.launch"() ( { %5 = "tf.TPUExecute"(%4#0, %4#1, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %5 : tensor }) {device = "/device:TPU:0"} : () -> tensor return %execute1 : tensor @@ -85,9 +85,9 @@ func @on_tpu_iter(%arg0: tensor<*x!tf.resource> {tf.device = "/device:TPU:0"}) - NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. 
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) // CHECK-NOT: "tf.TPUGetLayoutOp" // CHECK-NOT: "tf.TPUCopyWithLayout" %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:TPU:0"} @@ -98,7 +98,7 @@ func @on_tpu_iter(%arg0: tensor<*x!tf.resource> {tf.device = "/device:TPU:0"}) - }) {device = "/device:CPU:0"} : () -> () %execute = "tf_device.launch"() ( { %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %3 : tensor }) {device = "/device:TPU:0"} : () -> tensor return %execute : tensor @@ -116,9 +116,9 @@ func @arg_on_tpu_iter_on_cpu(%arg0: tensor<*x!tf.resource> {tf.device = "/device NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) // CHECK-NOT: "tf.TPUGetLayoutOp" // CHECK-NOT: "tf.TPUCopyWithLayout" %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"} @@ -129,7 +129,7 @@ func @arg_on_tpu_iter_on_cpu(%arg0: tensor<*x!tf.resource> {tf.device = "/device }) {device = "/device:CPU:0"} : () -> () %execute = "tf_device.launch"() ( { %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %3 : tensor }) {device = "/device:TPU:0"} : () -> tensor return %execute : tensor @@ -148,9 +148,9 @@ func @arg_on_tpu_intermediate_ops_on_cpu(%arg0: tensor<*x!tf.resource> {tf.devic NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. 
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) %id1 = "tf.Identity"(%arg0) {device = "/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<*x!tf.resource>) %id2 = "tf.Identity"(%id1) {device = "/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<*x!tf.resource>) // CHECK-NOT: "tf.TPUGetLayoutOp" @@ -163,7 +163,7 @@ func @arg_on_tpu_intermediate_ops_on_cpu(%arg0: tensor<*x!tf.resource> {tf.devic }) {device = "/device:CPU:0"} : () -> () %execute = "tf_device.launch"() ( { %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %3 : tensor }) {device = "/device:TPU:0"} : () -> tensor return %execute : tensor @@ -181,9 +181,9 @@ func @var_handle_on_tpu_iter_on_cpu() -> tensor { NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) %var = "tf.VarHandleOp"() {container = "c", shared_name = "v", device = "/device:TPU:0"} : () -> tensor<*x!tf.resource> // CHECK-NOT: "tf.TPUGetLayoutOp" // CHECK-NOT: "tf.TPUCopyWithLayout" @@ -195,7 +195,7 @@ func @var_handle_on_tpu_iter_on_cpu() -> tensor { }) {device = "/device:CPU:0"} : () -> () %execute = "tf_device.launch"() ( { %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %3 : tensor }) {device = "/device:TPU:0"} : () -> tensor return %execute : tensor @@ -212,9 +212,9 @@ func @unsupported_ops(%arg0: tensor<3x3x1x32xf32> {tf.device = "/device:CPU:0"}) NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. 
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) // CHECK-NOT: "tf.TPUGetLayoutOp" // CHECK-NOT: "tf.TPUCopyWithLayout" %2 = "tf._Unknown_"() : () -> tensor<3x3x1x32xf32> @@ -224,7 +224,7 @@ func @unsupported_ops(%arg0: tensor<3x3x1x32xf32> {tf.device = "/device:CPU:0"}) }) {device = "/device:CPU:0"} : () -> () %execute = "tf_device.launch"() ( { %3 = "tf.TPUExecute"(%arg0, %2, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %3 : tensor }) {device = "/device:TPU:0"} : () -> tensor return %execute : tensor @@ -246,9 +246,9 @@ func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false} // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false} // CHECK: %[[ITER1:.*]]:2 = "tf.IteratorGetNext" @@ -267,7 +267,7 @@ func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}} { // CHECK: "tf.TPUExecute"(%[[R0]], %[[R1]], %[[COMPILE]]#1) %execute = "tf_device.launch"() ( { - %4 = "tf.TPUExecute"(%r0, %r1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + %4 = "tf.TPUExecute"(%r0, %r1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %4 : tensor }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor tf_device.return %execute : tensor @@ -286,9 +286,9 @@ func @inside_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU: NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. 
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) // CHECK-NOT: "tf.TPUGetLayoutOp" // CHECK-NOT: "tf.TPUCopyWithLayout" "tf_device.launch"() ( { @@ -300,7 +300,7 @@ func @inside_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU: %2:2 = "tf.IteratorGetNext"(%r0) : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>) %execute = "tf_device.launch"() ( { - %4 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + %4 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %4 : tensor }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor tf_device.return %execute : tensor @@ -330,9 +330,9 @@ func @parallel_execute(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0 // CHECK: %[[COMPILE:.*]]:3 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"() %compile:3 = "tf_device.launch"() ( { - %1:3 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\01 \02", mlir_module = "..."} : () -> (tensor, tensor, tensor) - tf_device.return %1#0, %1#1, %1#2 : tensor, tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor, tensor) + %1:3 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\01 \02", mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1, %1#2 : tensor, tensor<2x!tf.string>, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>, tensor<2x!tf.string>) // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false} // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#2) {index = 0 : i64, is_output = false} // CHECK: %[[ITER:.*]]:2 = "tf.IteratorGetNext" @@ -351,7 +351,7 @@ func @parallel_execute(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0 // CHECK-NEXT: tf_device.return // CHECK-NEXT: device = "/device:TPU:0" "tf_device.launch"() ( { - "tf.TPUExecute"(%2#0, %compile#1) : (tensor<128xf32>, tensor) -> () + "tf.TPUExecute"(%2#0, %compile#1) : (tensor<128xf32>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "/device:TPU:0"} : () -> () tf_device.return @@ -364,7 +364,7 @@ func @parallel_execute(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0 // CHECK-NEXT: tf_device.return // CHECK-NEXT: device = "/device:TPU:1" "tf_device.launch"() ( { - "tf.TPUExecute"(%2#1, %compile#2) : (tensor<128xf32>, tensor) -> () + "tf.TPUExecute"(%2#1, %compile#2) : (tensor<128xf32>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "/device:TPU:1"} : () -> () tf_device.return @@ -396,9 +396,9 @@ func @replicated_parallel_execute(%arg0: tensor<*x!tf.resource> {tf.device = "/d // CHECK: %[[COMPILE:.*]]:3 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"() %compile:3 = "tf_device.launch"() ( { - %1:3 = 
"tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\02 \02", mlir_module = "..."} : () -> (tensor, tensor, tensor) - tf_device.return %1#0, %1#1, %1#2 : tensor, tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor, tensor) + %1:3 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\02 \02", mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1, %1#2 : tensor, tensor<2x!tf.string>, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>, tensor<2x!tf.string>) // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false} // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#2) {index = 0 : i64, is_output = false} // CHECK-DAG: %[[ITER0:.*]]:2 = "tf.IteratorGetNext"(%[[ARG0]]) @@ -423,7 +423,7 @@ func @replicated_parallel_execute(%arg0: tensor<*x!tf.resource> {tf.device = "/d // CHECK-NEXT: tf_device.return // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" "tf_device.launch"() ( { - "tf.TPUExecute"(%r0, %compile#1) : (tensor<128xf32>, tensor) -> () + "tf.TPUExecute"(%r0, %compile#1) : (tensor<128xf32>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "TPU_REPLICATED_CORE_0"} : () -> () tf_device.return @@ -433,7 +433,7 @@ func @replicated_parallel_execute(%arg0: tensor<*x!tf.resource> {tf.device = "/d // CHECK-NEXT: tf_device.return // CHECK-NEXT: device = "TPU_REPLICATED_CORE_1" "tf_device.launch"() ( { - "tf.TPUExecute"(%r1, %compile#2) : (tensor<128xf32>, tensor) -> () + "tf.TPUExecute"(%r1, %compile#2) : (tensor<128xf32>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "TPU_REPLICATED_CORE_1"} : () -> () tf_device.return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir new file mode 100644 index 00000000000..a505a4e3269 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir @@ -0,0 +1,64 @@ +// RUN: tf-opt -tf-tpu-resource-read-for-write %s | FileCheck %s --dump-input=always + +// CHECK-LABEL: func @write_only_resource +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor<*x!tf.resource>>) +func @write_only_resource(%arg0: tensor, %arg1: tensor, %arg2: tensor<*x!tf.resource>>) { + // CHECK-NEXT: [[READ:%.*]] = "tf.ReadVariableOp"([[ARG2]]) + // CHECK-NEXT: [[CLUSTER:%.*]]:2 = "tf_device.cluster_func"([[ARG0]], [[ARG1]], [[READ]]) + // CHECK-SAME: _tpu_replicate = "write" + %0:2 = "tf_device.cluster_func"(%arg0, %arg1) {_tpu_replicate = "write", func = @write_func} : (tensor, tensor) -> (tensor, tensor) + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG2]], [[CLUSTER]]#1) + "tf.AssignVariableOp"(%arg2, %0#1) : (tensor<*x!tf.resource>>, tensor) -> () + // CHECK-NEXT: return + return +} + +// CHECK-LABEL: func @write_func +// CHECK-SAME: ({{%.*}}: tensor, {{%.*}}: tensor, {{%.*}}: tensor) -> (tensor, tensor) +func @write_func(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + return %arg1, %arg0 : tensor, tensor +} + +// CHECK-LABEL: func @read_write_resource +func @read_write_resource(%arg0: tensor, %arg1: tensor, %arg2: tensor<*x!tf.resource>>) { + // CHECK-COUNT-1: tf.ReadVariableOp + %0 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf.resource>>) -> tensor + %1:2 = "tf_device.cluster_func"(%arg0, %arg1, %0) {_tpu_replicate = "read_write", 
func = @read_write_func} : (tensor, tensor, tensor) -> (tensor, tensor) + "tf.AssignVariableOp"(%arg2, %1#1) : (tensor<*x!tf.resource>>, tensor) -> () + return +} + +// CHECK-LABEL: func @read_write_func +// CHECK-SAME: ({{%.*}}: tensor, {{%.*}}: tensor) -> (tensor, tensor) +func @read_write_func(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + return %arg1, %arg0 : tensor, tensor +} + +// CHECK-LABEL: func @multiple_write_resource +func @multiple_write_resource(%arg0: tensor, %arg1: tensor<*x!tf.resource>>) { + // CHECK-NOT: tf.ReadVariableOp + %0:2 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "multiple_write", func = @multiple_write_func} : (tensor) -> (tensor, tensor) + "tf.AssignVariableOp"(%arg1, %0#0) : (tensor<*x!tf.resource>>, tensor) -> () + "tf.AssignVariableOp"(%arg1, %0#1) : (tensor<*x!tf.resource>>, tensor) -> () + return +} + +// CHECK-LABEL: func @multiple_write_func +// CHECK-SAME: ({{%.*}}: tensor) -> (tensor, tensor) +func @multiple_write_func(%arg0: tensor) -> (tensor, tensor) { + return %arg0, %arg0 : tensor, tensor +} + +// CHECK-LABEL: func @multiple_result_user +func @multiple_result_user(%arg0: tensor, %arg1: tensor<*x!tf.resource>>) -> tensor { + // CHECK-NOT: tf.ReadVariableOp + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "multiple_uses", func = @multiple_result_user_func} : (tensor) -> tensor + "tf.AssignVariableOp"(%arg1, %0) : (tensor<*x!tf.resource>>, tensor) -> () + return %0 : tensor +} + +// CHECK-LABEL: func @multiple_result_user_func +// CHECK-SAME: ({{%.*}}: tensor) -> tensor +func @multiple_result_user_func(%arg0: tensor) -> tensor { + return %arg0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir index 1e308b42bfc..277e4a8415e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir @@ -61,9 +61,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. 
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %2#0, %2#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %2#0, %2#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) "tf_device.launch"() ( { "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () tf_device.return @@ -86,7 +86,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr "tf_device.launch"() ( { "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1) {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} - : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> () + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "TPU_REPLICATED_CORE_0"} : () -> () %ret = "tf.Const"() {value = dense<0> : tensor} : () -> tensor @@ -153,9 +153,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %2#0, %2#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %2#0, %2#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) "tf_device.launch"() ( { "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () tf_device.return @@ -173,7 +173,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr "tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %compile#1) {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor) -> () + tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "TPU_REPLICATED_CORE_0"} : () -> () tf_device.return @@ -239,9 +239,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. 
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %2#0, %2#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %2#0, %2#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) "tf_device.launch"() ( { "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () tf_device.return @@ -254,7 +254,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr "tf_device.launch"() ( { "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1) {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} - : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> () + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "TPU_REPLICATED_CORE_0"} : () -> () tf_device.return @@ -342,9 +342,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %2#0, %2#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %2#0, %2#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) "tf_device.launch"() ( { "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () tf_device.return @@ -367,7 +367,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr "tf_device.launch"() ( { "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1) {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} - : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> () + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "TPU_REPLICATED_CORE_0"} : () -> () %ret = "tf.Const"() {value = dense<0> : tensor} : () -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir index 37dfec5e6df..281e4baaa12 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir @@ -331,6 +331,155 @@ func @mirrored_variables(%arg0: tensor>>, %arg1: ten // CHECK-SAME: _replicated_input_indices = [0, 1, 2] +// Test resource usage after resource use in cluster is moved to after the +// cluster. 
+// CHECK-LABEL: func @resource_after_cluster
+// CHECK-SAME: ([[USED_RESOURCE:%.*]]: tensor<*x!tf.resource>>, [[UNUSED_RESOURCE:%.*]]: tensor<*x!tf.resource>>)
+func @resource_after_cluster(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>) {
+  // CHECK-NEXT: [[CONST:%.*]] = "tf.Const"
+  %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor
+
+  // CHECK-NEXT: "tf.AssignSubVariableOp"([[UNUSED_RESOURCE]], [[CONST]])
+
+  // CHECK: "tf_device.cluster"
+  // CHECK-NEXT: "tf.ReadVariableOp"([[USED_RESOURCE]])
+  // CHECK-NEXT: "tf.NoOp"
+  // CHECK-NEXT: tf_device.return
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster_test_fn", allow_soft_placement = false, computation_shape = [], device_assignment = [], host_compute_core = [], num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : () -> ()
+  %1 = "tf.ReadVariableOp"(%arg0) {_tpu_replicate = "cluster_test_fn"} : (tensor<*x!tf.resource>>) -> tensor
+
+  "tf.AssignSubVariableOp"(%arg1, %0) : (tensor<*x!tf.resource>>, tensor) -> ()
+
+  // CHECK: "tf.AssignAddVariableOp"([[USED_RESOURCE]], [[CONST]])
+  "tf.AssignAddVariableOp"(%arg0, %0) : (tensor<*x!tf.resource>>, tensor) -> ()
+
+  "tf.NoOp"() {_tpu_replicate = "cluster_test_fn"} : () -> ()
+  return
+}
+
+
+// Test resource not used by cluster is moved to before the cluster.
+// CHECK-LABEL: func @resource_before_cluster
+func @resource_before_cluster() {
+  // CHECK-NEXT: [[CONST:%.*]] = "tf.Const"
+  %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor
+
+  // CHECK-NEXT: [[UNUSED_RESOURCE:%.*]] = "tf.VarHandleOp"
+  // CHECK-NEXT: "tf.AssignAddVariableOp"([[UNUSED_RESOURCE]], [[CONST]])
+
+  // CHECK: "tf_device.cluster"
+  // CHECK-NEXT: "tf.NoOp"
+  // CHECK-NEXT: tf_device.return
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster_test_fn", allow_soft_placement = false, computation_shape = [], device_assignment = [], host_compute_core = [], num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : () -> ()
+
+  %1 = "tf.VarHandleOp"() {container = "", shape = #tf.shape<>, shared_name = "x"} : () -> tensor<*x!tf.resource>>
+  "tf.AssignAddVariableOp"(%1, %0) : (tensor<*x!tf.resource>>, tensor) -> ()
+
+  "tf.NoOp"() {_tpu_replicate = "cluster_test_fn"} : () -> ()
+  return
+}
+
+
+// Test cluster formation with ops with attached regions within a cluster.
+// Nested ops that are moved should get their _tpu_replicate and device
+// attributes cleared.
+// CHECK-LABEL: func @cluster_ops_with_regions
+func @cluster_ops_with_regions() {
+  %0 = "tf.opA"() ({
+    %1 = "tf.opB"() {_tpu_replicate = "replicate", device = "device", name = "nameB"} : () -> (tensor)
+  }) {_tpu_replicate = "replicate", device = "device", name = "nameA"} : () -> tensor
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
+  return
+}
+
+// CHECK: "tf.opA"() ( {
+// CHECK-NEXT: "tf.opB"
+// CHECK-NOT: _tpu_replicate = "replicate"
+// CHECK-NOT: device = "device"
+// CHECK-SAME: name = "nameB"
+// CHECK: })
+// CHECK-NOT: _tpu_replicate = "replicate"
+// CHECK-NOT: device = "device"
+// CHECK: name = "nameA"
+// CHECK: tf_device.return
+
+// A nested cluster op using result of another cluster op.
In the below, opA and +// opB go in a cluster, and opD stays outside. +// CHECK-LABEL: func @cluster_nested_op_using_other_op +func @cluster_nested_op_using_other_op() { + %0 = "tf.opA"() { _tpu_replicate = "foo" } : () -> tensor + "tf.opB"() ({ + "tf.opC"(%0) : (tensor) -> () + }) { _tpu_replicate = "foo" } : () -> () + "tf.opD"(%0) : (tensor) -> () + "tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> () + return +} + +// CHECK: [[CLUSTER:%.*]] = "tf_device.cluster"() ( { +// CHECK: [[OPA:%.*]] = "tf.opA"() : () -> tensor +// CHECK: "tf.opB"() ( { +// CHECK: "tf.opC"([[OPA]]) +// CHECK: tf_device.return [[OPA]] +// CHECK: "tf.opD"([[CLUSTER]]) + +// Preceding user is using resource updated by a nested op. +!tf_res = type tensor<*x!tf.resource>> +// CHECK-LABEL: func @cluster_nested_op_updating_resource +func @cluster_nested_op_updating_resource() { + %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %1 = "tf.VarHandleOp"() {container = "", shape = #tf.shape<>, shared_name = "x"} : () -> !tf_res + + "tf.opA"() ({ + "tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor) -> () + "tf.terminator"() : () -> () + }) { _tpu_replicate = "foo" } : () -> () + "tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor) -> () + "tf.opB"() { _tpu_replicate = "foo" } : () -> () + "tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> () + return +} + +// CHECK: [[CONST:%.*]] = "tf.Const" +// CHECK: [[VAR:%.*]] = "tf.VarHandleOp" +// CHECK: "tf_device.cluster"() ( { +// CHECK: "tf.opA"() ( { +// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]]) +// CHECK: }) +// CHECK: "tf.opB"() +// CHECK: tf_device.return +// CHECK: }) +// CHECK-SAME: _tpu_replicate = "foo" +// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]]) + +// Preceding user is using resource updated by the cluster within a nested op. +// Resource is updated by a cluster op, and opA (not in cluster) is using the +// resource in a nested op. We expect opA to be after the cluster. 
+// CHECK-LABEL: func @cluster_nested_op_using_resource +func @cluster_nested_op_using_resource() { + %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %1 = "tf.VarHandleOp"() {container = "", shape = #tf.shape<>, shared_name = "x"} : () -> !tf_res + "tf.AssignAddVariableOp"(%1, %0) { _tpu_replicate = "foo" } : (!tf_res, tensor) -> () + "tf.opA"() ({ + "tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor) -> () + "tf.terminator"() : () -> () + }) : () -> () + "tf.opB"() { _tpu_replicate = "foo" } : () -> () + "tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> () + return +} + +// CHECK: [[CONST:%.*]] = "tf.Const" +// CHECK: [[VAR:%.*]] = "tf.VarHandleOp" +// CHECK: "tf_device.cluster"() ( { +// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]]) +// CHECK: "tf.opB"() +// CHECK: tf_device.return +// CHECK: }) +// CHECK-SAME: _tpu_replicate = "foo" +// CHECK: "tf.opA"() ( { +// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]]) + // ----- @@ -358,18 +507,6 @@ func @bad_num_replicas() { // ----- -// Test that functions without TPUReplicateMetadata op are skipped without -// error -// CHECK-LABEL: func @missing_metadata_op -func @missing_metadata_op() { - // expected-warning@+1 {{TPUReplicateMetadata for associated '_tpu_replicate' attribute 'replicate' is missing}} - %0 = "tf.opA"() {_tpu_replicate = "replicate"} : () -> tensor - return -} - -// ----- - - // Test cluster with TPUReplicatedInput where the number of operands does not // match associated `num_replicas` attribute. func @mismatched_replicated_input(%arg0: tensor) { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_composite_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_composite_resource_ops.mlir new file mode 100644 index 00000000000..88af4535d81 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_composite_resource_ops.mlir @@ -0,0 +1,118 @@ +// RUN: tf-opt %s -tf-tpu-colocate-composite-resource-ops | FileCheck %s + +// Tests ReadVariable op using composite device resource is wrapped inside +// tf_device.Cluster. 
+ +// CHECK-LABEL: func @testReadVariableOpColocated +// CHECK-SAME: (%[[ARG0:.*]]: tensor<*x!tf.resource>>) +func @testReadVariableOpColocated(%arg0: tensor<*x!tf.resource>>) { + // CHECK: tf_device.replicate + // CHECK-SAME: (%[[ARG0]] as %[[RI_0:[a-z0-9]*]]: tensor<*x!tf.resource>>) + tf_device.replicate(%arg0 as %arg1: tensor<*x!tf.resource>>) { + _mirrored_variable_indices = [0], _replicated_input_indices = [-1], + devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]}, + n = 2 : i32} { + // CHECK: %[[RESOURCE_OUT:.*]] = "tf_device.launch"() + // CHECK-NEXT: %[[READ_OUT:.*]] = "tf.ReadVariableOp"(%[[RI_0]]) + // CHECK-NEXT: tf_device.return %[[READ_OUT]] + // CHECK-NEXT: TPU_REPLICATED_CORE_0 + %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + %1 = "tf.A"() : () -> (tensor<2x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUExecuteAndUpdateVariables"(%arg1, %1) {device_var_reads_indices = [0], device_var_updates_indices = [-1]} : (tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + "tf_device.launch"() ( { + // CHECK: "tf.B"(%[[RESOURCE_OUT]]) + "tf.B"(%0) : (tensor<4xf32>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + tf_device.return + } + return +} + +// CHECK-LABEL: func @testReadVariableOpAfterIdentityColocated +// CHECK-SAME: (%[[ARG0:.*]]: tensor<*x!tf.resource>>) +func @testReadVariableOpAfterIdentityColocated(%arg0: tensor<*x!tf.resource>>) { + // CHECK: tf_device.replicate + // CHECK-SAME: (%[[ARG0]] as %[[RI_0:[a-z0-9]*]]: tensor<*x!tf.resource>>) + tf_device.replicate(%arg0 as %arg1: tensor<*x!tf.resource>>) { + _mirrored_variable_indices = [0], _replicated_input_indices = [-1], + devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]}, + n = 2 : i32} { + // CHECK: %[[IDENTITY_OUT:.*]] = "tf.Identity"(%[[RI_0]]) + // CHECK: %[[RESOURCE_OUT:.*]] = "tf_device.launch"() + // CHECK-NEXT: %[[READ_OUT:.*]] = "tf.ReadVariableOp"(%[[IDENTITY_OUT]]) + // CHECK-NEXT: tf_device.return %[[READ_OUT]] + // CHECK-NEXT: TPU_REPLICATED_CORE_0 + %0 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + %1 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + %2 = "tf.A"() : () -> (tensor<2x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUExecuteAndUpdateVariables"(%arg1, %2) {device_var_reads_indices = [0], device_var_updates_indices = [-1]} : (tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + "tf_device.launch"() ( { + // CHECK: "tf.B"(%[[RESOURCE_OUT]]) + "tf.B"(%1) : (tensor<4xf32>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + tf_device.return + } + return +} + +// Tests AssignVariable op using composite device resource is wrapped inside +// tf_device.Cluster. 
+ +// CHECK-LABEL: func @testAssignVariableOpColocated +// CHECK-SAME: (%[[ARG0:.*]]: tensor<*x!tf.resource>>) +func @testAssignVariableOpColocated(%arg0: tensor<*x!tf.resource>>) { + // CHECK: tf_device.replicate + // CHECK-SAME: (%[[ARG0]] as %[[RI_0:[a-z0-9]*]]: tensor<*x!tf.resource>>) + tf_device.replicate(%arg0 as %arg1: tensor<*x!tf.resource>>) { + _mirrored_variable_indices = [0], _replicated_input_indices = [-1], + devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]}, + n = 2 : i32} { + // CHECK: %[[VAL_OUT:.*]] = "tf.A"() : () -> tensor<4xf32> + // CHECK: "tf_device.launch"() + // CHECK-NEXT: "tf.AssignVariableOp"(%[[RI_0]], %[[VAL_OUT]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: TPU_REPLICATED_CORE_0 + %1 = "tf.A"() : () -> (tensor<4xf32>) + "tf.AssignVariableOp"(%arg1, %1) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + %2 = "tf.B"() : () -> (tensor<2x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUExecuteAndUpdateVariables"(%arg1, %2) {device_var_reads_indices = [0], device_var_updates_indices = [-1]} : (tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + tf_device.return + } + return +} + +// Tests tf_device.replicate op not running on TPU devices ignored. + +// CHECK-LABEL: func @testNonTPUDeviceReplicationIgnored +// CHECK-SAME: (%[[ARG0:.*]]: tensor<*x!tf.resource>>) +func @testNonTPUDeviceReplicationIgnored(%arg0: tensor<*x!tf.resource>>) { + // CHECK: tf_device.replicate + // CHECK-SAME: (%[[ARG0]] as %[[RI_0:[a-z0-9]*]]: tensor<*x!tf.resource>>) + tf_device.replicate(%arg0 as %arg1: tensor<*x!tf.resource>>) { + _mirrored_variable_indices = [0], _replicated_input_indices = [-1], + devices = {TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:1"]}, + n = 2 : i32} { + // CHECK: %[[VAL_OUT:.*]] = "tf.A"() : () -> tensor<4xf32> + // CHECK-NEXT: "tf.AssignVariableOp"(%[[RI_0]], %[[VAL_OUT]]) + %1 = "tf.A"() : () -> (tensor<4xf32>) + "tf.AssignVariableOp"(%arg1, %1) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + %2 = "tf.B"() : () -> (tensor<2x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUExecuteAndUpdateVariables"(%arg1, %2) {device_var_reads_indices = [0], device_var_updates_indices = [-1]} : (tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () + tf_device.return + }) {device = "TPU_REPLICATED_HOST"} : () -> () + tf_device.return + } + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir index 32a8000ea82..8ae6fa958a6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir @@ -173,7 +173,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor func @tail_single_outside_compiled_op() { // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" - // CHECK-NEXT: "tf.C" + // CHECK-NEXT: "tf.NoOp" // CHECK-NEXT: tf_device.return %[[A_OUT]] // CHECK-NEXT: { // CHECK-DAG: num_cores_per_replica = 1 @@ -190,7 +190,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor "tf_device.cluster"() ( { %a = "tf.A"() : () -> tensor "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> 
() - "tf.C"() : () -> () + "tf.NoOp"() : () -> () tf_device.return }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () return @@ -200,7 +200,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor func @tail_single_outside_compiled_op_user() -> tensor { // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" - // CHECK-NEXT: "tf.C" + // CHECK-NEXT: "tf.NoOp" // CHECK-NEXT: tf_device.return %[[A_OUT]] // CHECK-NEXT: { // CHECK-DAG: num_cores_per_replica = 1 @@ -217,7 +217,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %cluster = "tf_device.cluster"() ( { %a = "tf.A"() : () -> tensor %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor - "tf.C"() : () -> () + "tf.NoOp"() : () -> () tf_device.return %b : tensor }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor // CHECK: return %[[LAUNCH_OUT]] @@ -262,7 +262,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %b = "tf.B"() : () -> tensor // CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster" // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C" - // CHECK-NEXT: %[[E_OUT:.*]] = "tf.E" + // CHECK-NEXT: %[[E_OUT:.*]] = "tf.Const" // CHECK-NEXT: tf_device.return %[[C_OUT]], %[[E_OUT]] // CHECK-NEXT: { // CHECK-DAG: num_cores_per_replica = 1 @@ -279,7 +279,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %cluster:5 = "tf_device.cluster"() ( { %c = "tf.C"() : () -> tensor %d = "tf.D"(%c, %a) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor - %e = "tf.E"() : () -> tensor + %e = "tf.Const"() {value = dense<0> : tensor} : () -> tensor tf_device.return %a, %b, %c, %d, %e : tensor, tensor, tensor, tensor, tensor }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor, tensor, tensor, tensor, tensor) // CHECK: return %[[A_OUT]], %[[B_OUT]], %[[CLUSTER_OUT]]#0, %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#1 @@ -320,14 +320,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor func @head_tail_no_extraction_middle_outside_compiled_ops(%arg0: tensor) { // CHECK-NOT: "tf_device.launch" // CHECK: "tf_device.cluster" - // CHECK-NEXT: "tf.A" + // CHECK-NEXT: "tf.Identity" // CHECK-NEXT: "tf.B" - // CHECK-NEXT: "tf.C" + // CHECK-NEXT: "tf.Identity" // CHECK-NEXT: tf_device.return "tf_device.cluster"() ( { - %a = "tf.A"(%arg0) : (tensor) -> tensor + %a = "tf.Identity"(%arg0) : (tensor) -> tensor %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor - "tf.C"(%b) : (tensor) -> () + %c = "tf.Identity"(%b) : (tensor) -> tensor tf_device.return }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () return @@ -379,7 +379,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[RI]], %[[B_OUT]]) - // CHECK-NEXT: "tf.E"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]]) + // CHECK-NEXT: "tf.IdentityN"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]]) // CHECK-NEXT: tf_device.return %[[C_OUT]] // CHECK-NEXT: { // CHECK-DAG: num_cores_per_replica = 1 @@ -399,11 +399,139 @@ module attributes {tf.versions = {producer 
= 888 : i32}, tf.devices = ["/job:wor %b = "tf.B"() : () -> tensor %c = "tf.C"(%ri, %b) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor %d = "tf.D"(%a, %c, %ri) {_xla_outside_compilation = "cluster1"} : (tensor, tensor, tensor) -> tensor - %e = "tf.E"(%c, %a) : (tensor, tensor) -> tensor + %e:2 = "tf.IdentityN"(%c, %a) : (tensor, tensor) -> (tensor, tensor) tf_device.return }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () tf_device.return } return } + + // CHECK-LABEL: func @side_effect_middle + func @side_effect_middle() { + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + "tf.A"() : () -> () + "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () + "tf.C"() : () -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @side_effect_head_no_operand + func @side_effect_head_no_operand() { + // CHECK: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C" + // CHECK-NEXT: tf_device.return %[[C_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return + + "tf_device.cluster"() ( { + %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () + %c = "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> tensor + "tf.D"(%c) : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @side_effect_tail_no_operand + func @side_effect_tail_no_operand() { + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: tf_device.return %[[A_OUT]] + + // CHECK: "tf_device.launch"() + // CHECK-NEXT: "tf.B"(%[[CLUSTER_OUT]]) + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + "tf_device.cluster"() ( { + %a = "tf.A"() : () -> tensor + "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> () + %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // Test embedding ops can be head extracted and side effect analysis + // predecessors are ignored. 
+ + // CHECK-LABEL: func @embedding_head_extraction + func @embedding_head_extraction(%arg0: tensor) { + // CHECK: "tf_device.launch"() + // CHECK-NEXT: "tf.EnqueueTPUEmbeddingRaggedTensorBatch" + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.UnknownOp" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + "tf.UnknownOp"() : () -> () + "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0) {_xla_outside_compilation = "cluster1", table_ids = [1, 2]} : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // Test side effecting op after embedding op can be head extracted. + + // CHECK-LABEL: func @op_after_embedding_head_extraction + func @op_after_embedding_head_extraction() { + // CHECK: "tf_device.launch"() + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.RecvTPUEmbeddingActivations" + // CHECK-NEXT: "tf.SendTPUEmbeddingGradients" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.RecvTPUEmbeddingActivations"() {config = "test_config_recv_embedding"} : () -> tensor<512x256xf32> + "tf.SendTPUEmbeddingGradients"(%0) {N = 1 : i64, NN = 0 : i64, config = "test_config_send_embedding", operand_segment_sizes = dense<[1, 0]> : vector<2xi32>} : (tensor<512x256xf32>) -> () + "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // Test side effecting op before embedding op can be tail extracted. + + // CHECK-LABEL: func @op_before_embedding_tail_extraction + func @op_before_embedding_tail_extraction() { + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.UnknownOp" + // CHECK-NEXT: "tf.RecvTPUEmbeddingActivations" + // CHECK-NEXT: "tf.SendTPUEmbeddingGradients" + // CHECK-NEXT: tf_device.return + + // CHECK: "tf_device.launch"() + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + "tf_device.cluster"() ( { + "tf.UnknownOp"() : () -> () + "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> () + %0 = "tf.RecvTPUEmbeddingActivations"() {config = "test_config_recv_embedding"} : () -> tensor<512x256xf32> + "tf.SendTPUEmbeddingGradients"(%0) {N = 1 : i64, NN = 0 : i64, config = "test_config_send_embedding", operand_segment_sizes = dense<[1, 0]> : vector<2xi32>} : (tensor<512x256xf32>) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir index 732e34fce90..9b828e42844 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir @@ -456,4 +456,739 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor } return %1 : tensor } + + // Tests extraction of a single outside compiled cluster inside a tf.IfRegion op. 
+ + // CHECK-LABEL: func @outside_compiled_ops_inside_tf_if + func @outside_compiled_ops_inside_tf_if(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) + // CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1) + // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-NEXT: "tf.Yield"() : () -> () + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0"} + // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) + // CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: tpu_core = 0 + // CHECK-NEXT: "tf.Yield"() : () -> () + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.IfRegion"(%6) ({ + "tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> () + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) { is_stateless = false} : (tensor) -> () + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of an outside compiled tf.IfRegion op where the entirety + // of tf.IfRegion op is outside compiled + + // CHECK-LABEL: func @outside_compiled_tf_if + func @outside_compiled_tf_if(%arg0: tensor) -> tensor { + // CHECK: %[[A_OUT:[0-9]*]] = "tf.A" + // CHECK: %[[F_OUT:[0-9]*]] = "tf.F" + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: (tensor<2x!tf.string>) -> (tensor, tensor, tensor) + // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2) + // CHECK: "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]]) + // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: 
%[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: tpu_core = 0 + %0 = "tf.A"(%arg0) : (tensor) -> tensor + %7 = "tf.F"() : () -> tensor + + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.IfRegion"(%6) ({ + "tf.D"(%4, %3, %7) {} : (tensor, tensor, tensor) -> () + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) {_xla_outside_compilation = "cluster1", is_stateless = false} : (tensor) -> () + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of an outside compiled tf.IfRegion op where the entirety + // of tf.IfRegion op is outside compiled and wrapped inside another + // tf.IfRegion op + + // CHECK-LABEL: func @outside_compiled_tf_if_nested + func @outside_compiled_tf_if_nested(%arg0: tensor) -> tensor { + // CHECK: %[[A_OUT:[0-9]*]] = "tf.A" + // CHECK: %[[F_OUT:[0-9]*]] = "tf.F" + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[RECV_OUTPUT_PREDICATE:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: (tensor<2x!tf.string>) -> tensor + // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT_PREDICATE]]) + // CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: (tensor<2x!tf.string>) -> (tensor, tensor) + // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#1) + // CHECK-NEXT: "tf.H"(%[[RECV_OUTPUT]]#0, %[[F_OUT]]) + // CHECK: "tf.Yield"() : () -> () + // CHECK: "tf.Yield"() : () -> () + // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]]) + // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: (tensor) -> () + // CHECK-NEXT: "tf.IfRegion"(%[[G_OUTPUT]]) + // CHECK: %[[D_OUT:[0-9]*]] = "tf.D" + // CHECK-NEXT: %[[F_OUT:[0-9]*]] = "tf.F" + // CHECK: "tf._XlaHostComputeMlir"(%[[D_OUT]], %[[F_OUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: tpu_core = 0 + // CHECK: "tf.Yield"() : () -> () + // CHECK: "tf.Yield"() : () -> () + %0 = "tf.A"(%arg0) : (tensor) -> tensor + %7 = "tf.F"() : () -> tensor + + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.IfRegion"(%6) ({ + %8 = "tf.D"(%4, %3, %7) {} 
: (tensor, tensor, tensor) -> (tensor) + %9 = "tf.F"(%4) {} : (tensor) -> (tensor) + + "tf.IfRegion"(%9) ({ + "tf.H"(%8, %7) : (tensor, tensor) -> () + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) {_xla_outside_compilation = "cluster1", is_stateless = false} : (tensor) -> () + + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) {is_stateless = false} : (tensor) -> () + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of a single outside compiled cluster inside a tf.IfRegion + // op with return values. + + // CHECK-LABEL: func @outside_compiled_ops_inside_tf_if_with_return_values + func @outside_compiled_ops_inside_tf_if_with_return_values( + %arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) + // CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1) + // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-NEXT: "tf.Yield"() : () -> () + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0"} + // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) + // CHECK: %[[HOST_COMPUTE_OUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: tpu_core = 0 + // CHECK-NEXT: "tf.Yield"(%[[HOST_COMPUTE_OUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.IfRegion"(%6) ({ + %7 = "tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> (tensor) + "tf.Yield"(%7) : (tensor) -> () + }, { + + %8 = "tf.F"() : () -> (tensor) + "tf.Yield"(%8) : (tensor) -> () + }) { is_stateless = false} : (tensor) -> (tensor) + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of a single outside compiled cluster inside a tf.IfRegion op without external inputs/outputs + + // CHECK-LABEL: func @outside_compiled_ops_inside_tf_if_without_input_outputs + func @outside_compiled_ops_inside_tf_if_without_input_outputs( + %arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> 
tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) + // CHECK: "tf.D" + // CHECK-NEXT: "tf.Yield"() : () -> () + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0"} + // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"() : () -> () + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.IfRegion"(%6) ({ + "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> () + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) { is_stateless = false} : (tensor) -> () + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of a single outside compiled cluster inside a nested + // tf.IfRegion op. + + // CHECK-LABEL: func @outside_compiled_ops_inside_nested_if + func @outside_compiled_ops_inside_nested_if(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) + // CHECK-NEXT: %[[PREDICATE2_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "if_predicate_channel_cluster1_1" + // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE2_RECV_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"() : () -> () + // CHECK: %[[ARG_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]]) + // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-NEXT: "tf.Yield"() : () -> () + + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]]) {key = "if_predicate_channel_cluster1_0"} + // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) + // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"(%[[B_OUTPUT]]) + // CHECK: "tf.XlaSendToHost"(%[[H_OUTPUT]]) {key = "if_predicate_channel_cluster1_1"} + // CHECK-NEXT: tf.IfRegion"(%[[H_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"() : () -> () + // CHECK: 
%[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[H_OUTPUT]]) + // CHECK: "tf._XlaHostComputeMlir"(%[[I_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"() : () -> () + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.IfRegion"(%6) ({ + %7 = "tf.H"(%4) : (tensor) -> (tensor) + + "tf.IfRegion"(%7)({ + "tf.Yield"() : () -> () + }, + { + %8 = "tf.I"(%7) : (tensor) -> (tensor) + "tf.D"(%8) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + "tf.Yield"() : () -> () + }) { is_stateless = false} : (tensor) -> () + + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) { is_stateless = false} : (tensor) -> () + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op body. + + // CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_body + func @outside_compiled_ops_inside_tf_while_body(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: tf.WhileRegion" + // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "while_condition_channel_cluster1_0" + // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]]) + // CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D" + // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]]) + // CHECK_NEXT: "tf.Yield" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]]) + // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H" + // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir" + // CHECK-NEXT "tf.Yield"(%[[C_OUTPUT]], %[[HOST_COMPUTE_OUTPUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.WhileRegion"(%4, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.H"(%arg1) : (tensor) -> tensor + "tf.Yield"(%7) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %8 = "tf.C"(%arg1) : (tensor) -> tensor + %9 = "tf.D"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + "tf.Yield"(%8, %9) : (tensor, tensor) -> () + }) { is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op cond. 
+ + // CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_cond + func @outside_compiled_ops_inside_tf_while_cond(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: tf.WhileRegion" + // CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1) + // CHECK-NEXT: "tf._XlaSendFromHost"(%[[I_OUTPUT]], %[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "while_condition_channel_cluster1_0" + // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]]) + // CHECK_NEXT: "tf.Yield" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]]) + // CHECK "tf.XlaHostCompute" + // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H" + // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT: "tf.D" + // CHECK-NEXT "tf.Yield"(%[[C_OUTPUT]], %[[HOST_COMPUTE_OUTPUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.WhileRegion"(%4, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.I"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + %8 = "tf.H"(%7) : (tensor) -> tensor + "tf.Yield"(%8) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.C"(%arg1) : (tensor) -> tensor + %8 = "tf.D"(%arg1, %arg2) : (tensor, tensor) -> tensor + "tf.Yield"(%7, %8) : (tensor, tensor) -> () + }) { is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op cond and body. 
+ + // CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_cond_body + func @outside_compiled_ops_inside_tf_while_cond_body(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: tf.WhileRegion" + // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "while_condition_channel_cluster2_0" + // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]]) + // CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D" + // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]]) + // CHECK_NEXT: "tf.Yield" + // CHECK: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: tf.WhileRegion" + // CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1) + // CHECK-NEXT: "tf._XlaSendFromHost"(%[[I_OUTPUT]], %[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "while_condition_channel_cluster1_0" + // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]]) + // CHECK_NEXT: "tf.Yield" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]]) + // CHECK "tf.XlaHostCompute" + // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H" + // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]]) + // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT "tf.Yield"(%[[C_OUTPUT]], %[[HOST_COMPUTE_OUTPUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.WhileRegion"(%4, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.I"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + %8 = "tf.H"(%7) : (tensor) -> tensor + "tf.Yield"(%8) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.C"(%arg1) : (tensor) -> tensor + %8 = "tf.D"(%arg1, %arg2) {_xla_outside_compilation = "cluster2"} : (tensor, tensor) -> tensor + "tf.Yield"(%7, %8) : (tensor, tensor) -> () + }) { is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of a single outside compiled cluster inside a tf.IfRegion op + // nested in a tf.WhileRegion. 
+ + // CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_if + func @outside_compiled_ops_inside_tf_while_if(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: tf.WhileRegion" + // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "while_condition_channel_cluster1_0" + // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]]) + // CHECK: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) + // CHECK: "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D" + // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]]) + // CHECK_NEXT: "tf.Yield" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]]) + // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H" + // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT: "tf.XlaSendToHost"(%[[G_OUTPUT]]) + // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) + // CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir" + // CHECK-NEXT "tf.Yield"(%[[HOST_COMPUTE_OUTPUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.WhileRegion"(%4, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.H"(%arg1) : (tensor) -> tensor + "tf.Yield"(%7) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %8 = "tf.C"(%arg1) : (tensor) -> tensor + %10 = "tf.IfRegion"(%6) ({ + %9 = "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> tensor + "tf.Yield"(%9) : (tensor) -> () + }, { + "tf.Yield"(%arg2) : (tensor) -> () + }) { is_stateless = false} : (tensor) -> tensor + "tf.Yield"(%8, %10) : (tensor, tensor) -> () + }) { is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of an outside compiled tf.IfRegion op where the entirety + // of tf.IfRegion op is outside compiled with a nested tf.WhileRegion op. 
+ + // CHECK-LABEL: func @outside_compiled_tf_if_nested_while + func @outside_compiled_tf_if_nested_while(%arg0: tensor) -> tensor { + // CHECK: %[[A_OUT:[0-9]*]] = "tf.A" + // CHECK: %[[F_OUT:[0-9]*]] = "tf.F" + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: (tensor<2x!tf.string>) -> (tensor, tensor, tensor) + // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2) + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]]) + // CHECK-NEXT: %[[J_OUTPUT:[0-9]*]] = "tf.J" + // CHECK-NEXT: %[[K_OUTPUT:[0-9]*]] = "tf.K" + // CHECK-NEXT: tf.WhileRegion"(%[[J_OUTPUT]], %[[D_OUTPUT]]) + // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"(%[[K_OUTPUT]]) + // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: tpu_core = 0 + %0 = "tf.A"(%arg0) : (tensor) -> tensor + %7 = "tf.F"() : () -> tensor + + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.IfRegion"(%6) ({ + %8 = "tf.D"(%4, %3, %7) {} : (tensor, tensor, tensor) -> (tensor) + %9 = "tf.J"() : () -> (tensor) + %10 = "tf.K"() : () -> (tensor) + "tf.WhileRegion"(%9, %8) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %11 = "tf.I"(%arg1, %arg2) : (tensor, tensor) -> tensor + %12 = "tf.H"(%10) : (tensor) -> tensor + "tf.Yield"(%12) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %11 = "tf.C"(%arg1) : (tensor) -> tensor + %12 = "tf.D"(%arg1, %arg2) : (tensor, tensor) -> tensor + "tf.Yield"(%11, %12) : (tensor, tensor) -> () + }) { is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) {_xla_outside_compilation = "cluster1", is_stateless = false} : (tensor) -> () + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of an outside compiled tf.WhileRegion where the entire + // tf.WhileREgion op is outside compiled with a nested tf.IfRegion. 
+ + // CHECK-LABEL: func @outside_compiled_ops_tf_while_nested_if + func @outside_compiled_ops_tf_while_nested_if(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[HOST_RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK: "tf.WhileRegion"(%[[HOST_RECV_OUTPUT]]#1, %[[HOST_RECV_OUTPUT]]#2) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK: "tf.IfRegion"(%[[HOST_RECV_OUTPUT]]#0) + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]]) + // CHECK-NEXT: "tf.Yield" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]], %[[B_OUTPUT]], %[[A_OUTPUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.WhileRegion"(%4, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.H"(%arg1) : (tensor) -> tensor + "tf.Yield"(%7) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %8 = "tf.C"(%arg1) : (tensor) -> tensor + %10 = "tf.IfRegion"(%6) ({ + %9 = "tf.D"(%8) : (tensor) -> tensor + "tf.Yield"(%9) : (tensor) -> () + }, { + "tf.Yield"(%arg2) : (tensor) -> () + }) { is_stateless = false} : (tensor) -> tensor + "tf.Yield"(%8, %10) : (tensor, tensor) -> () + }) {_xla_outside_compilation = "cluster1", is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_identity_pruning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_identity_pruning.mlir new file mode 100644 index 00000000000..317e7036c42 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_identity_pruning.mlir @@ -0,0 +1,93 @@ +// RUN: tf-opt %s -tf-tpu-identity-pruning | FileCheck %s --dump-input=always + +// Tests Identity op in cluster is pruned away. + +// CHECK-LABEL: func @testIdentity +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @testIdentity(%arg0: tensor) { + // CHECK-NOT: "tf.Identity" + // CHECK: "tf_device.cluster" + // CHECK-NEXT: tf_device.return [[ARG0]] + %0 = "tf_device.cluster"() ( { + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + tf_device.return %1 : tensor + }) : () -> tensor + return +} + +// Tests IdentityN op in cluster is pruned away. + +// CHECK-LABEL: func @testIdentityN +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor) +func @testIdentityN(%arg0: tensor, %arg1: tensor) { + // CHECK-NOT: "tf.IdentityN" + // CHECK: "tf_device.cluster" + // CHECK-NEXT: tf_device.return [[ARG0]], [[ARG1]] + %0:2 = "tf_device.cluster"() ( { + %1:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor, tensor) -> (tensor, tensor) + tf_device.return %1#0, %1#1 : tensor, tensor + }) : () -> (tensor, tensor) + return +} + +// Tests transitive Identity ops reachable from the cluster are pruned away.
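// (The transitive-Identity test below covers Identities reached through tf.PartitionedCall callees.
// The fragment here is only a rough C++ sketch, under assumed names, of the simplest in-cluster case
// that the -tf-tpu-identity-pruning tests above exercise; PruneIdentitiesInCluster is a hypothetical
// helper, not the actual pass entry point, and the real pass also follows call chains.)
#include "llvm/Support/Casting.h"
#include "mlir/IR/Operation.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

static void PruneIdentitiesInCluster(mlir::Operation *cluster) {
  cluster->walk([](mlir::Operation *op) {
    if (!llvm::isa<mlir::TF::IdentityOp>(op) && !llvm::isa<mlir::TF::IdentityNOp>(op)) return;
    // Identity/IdentityN forward each operand to the result with the same
    // index, so uses can be rewired one-to-one before erasing the op.
    for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
      op->getResult(i).replaceAllUsesWith(op->getOperand(i));
    op->erase();
  });
}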
+ +// CHECK-LABEL: func @testTransitiveIdentity +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @testTransitiveIdentity(%arg0: tensor) { + // CHECK: "tf_device.cluster" + // CHECK: "tf.PartitionedCall"([[ARG0]]) + // CHECK-SAME: f = @callee0 + %0 = "tf_device.cluster"() ( { + %1 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee0} : (tensor) -> tensor + tf_device.return %1 : tensor + }) : () -> tensor + return +} + +// CHECK-LABEL: func @callee0 +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @callee0(%arg0: tensor) -> tensor { + // CHECK-NOT: "tf.Identity" + // CHECK: "tf.PartitionedCall"([[ARG0]]) + // CHECK-SAME: f = @callee1 + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor + %1 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee1} : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: func @callee1 +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @callee1(%arg0: tensor) -> tensor { + // CHECK-NOT: "tf.Identity" + // CHECK: return [[ARG0]] + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// Tests Identity ops not reachable from the cluster are not pruned away. + +// CHECK-LABEL: func @testIdentityOutsideCluster +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @testIdentityOutsideCluster(%arg0: tensor) { + // CHECK: [[IDENTITY:%.*]] = "tf.Identity"([[ARG0]]) + // CHECK: [[CLUSTER:%.*]] = "tf_device.cluster" + // CHECK-NEXT: tf_device.return [[IDENTITY]] + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor + %1 = "tf_device.cluster"() ( { + tf_device.return %0 : tensor + }) : () -> tensor + // CHECK: "tf.PartitionedCall"([[CLUSTER]]) + // CHECK-SAME: f = @callee2 + %2 = "tf.PartitionedCall"(%1) {config = "", config_proto = "", executor_type = "", f = @callee2} : (tensor) -> tensor + return +} + +// CHECK-LABEL: func @callee2 +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @callee2(%arg0: tensor) -> tensor { + // CHECK: [[IDENTITY:%.*]] = "tf.Identity"([[ARG0]]) + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor + // CHECK: return [[IDENTITY]] + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_outside_compilation_cluster.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_outside_compilation_cluster.mlir index 1394bd22dc8..269af51504f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_outside_compilation_cluster.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_outside_compilation_cluster.mlir @@ -75,7 +75,7 @@ func @two_clusters_no_dependencies() { // CHECK: "tf.opB" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER4:[a-zA-Z_0-9]+]]" // CHECK: "tf.opC" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER4]]" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER5:[a-zA-Z_0-9]+]]" // CHECK: "tf.opD" "tf_device.cluster"() ( { "tf.opA"() : () -> () @@ -135,6 +135,27 @@ func @two_clusters_with_two_ops_each() { return } +// CHECK-LABEL: func @resource_side_effect_cycle +func @resource_side_effect_cycle(%arg0: tensor>>, %arg1: tensor>>) { + // CHECK: "tf.ReadVariableOp" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1:[a-zA-Z_0-9]+]]" + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + // CHECK-NEXT: "tf.AssignVariableOp" + // CHECK-NOT: {_xla_outside_compilation = "[[CLUSTER1]]" + "tf_device.cluster"() ( { + %read0 = "tf.ReadVariableOp"(%arg0) {_xla_outside_compilation = "0"} : (tensor>>) -> tensor + %idet0 = "tf.Identity"(%read0) {_xla_outside_compilation = "0"} : (tensor) -> tensor + 
"tf.AssignVariableOp"(%arg1, %idet0) : (tensor>>, tensor) -> () + %read1 = "tf.ReadVariableOp"(%arg1) {_xla_outside_compilation = "0"} : (tensor>>) -> tensor + %idet1 = "tf.Identity"(%read1) {_xla_outside_compilation = "0"} : (tensor) -> tensor + %add0 = "tf.AddV2"(%idet0, %idet1) {_xla_outside_compilation = "0"} : (tensor, tensor) -> tensor + "tf.AssignVariableOp"(%arg0, %add0) {_xla_outside_compilation = "0"} : (tensor>>, tensor) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + // CHECK-LABEL: func @two_clusters_with_same_parent func @two_clusters_with_same_parent() { // CHECK: "tf.opA" @@ -144,10 +165,10 @@ func @two_clusters_with_same_parent() { // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER10]]" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER11:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.opD" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER10]]" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER12:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.opE" // CHECK-NEXT: "tf.opF" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER11]]" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER13:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.opG" "tf_device.cluster"() ( { %a = "tf.opA"() {_xla_outside_compilation = "0"} : () -> tensor @@ -171,8 +192,8 @@ func @two_clusters_with_same_outside_compiled_parent() { // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER12]]" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER13:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.opD" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER12]]" - // CHECK-NEXT: "tf.opE" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER14:[a-zA-Z_0-9]+]]" + // CHECK-NEXT: "tf.Identity" // CHECK-NEXT: "tf.opF" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER13]]" // CHECK-NEXT: "tf.opG" @@ -182,7 +203,7 @@ func @two_clusters_with_same_outside_compiled_parent() { %b = "tf.opB"(%a) : (tensor) -> tensor %c = "tf.opC"(%b) {_xla_outside_compilation = "0"} : (tensor) -> tensor %d = "tf.opD"() {_xla_outside_compilation = "0"} : () -> tensor - %e = "tf.opE"(%d) : (tensor) -> tensor + %e = "tf.Identity"(%d) : (tensor) -> tensor %f = "tf.opF"(%e) {_xla_outside_compilation = "0"} : (tensor) -> tensor %g = "tf.opG"(%c, %f) {_xla_outside_compilation = "0"} : (tensor, tensor) -> tensor tf_device.return @@ -213,14 +234,15 @@ func @outside_compile_with_block() { // CHECK-NEXT: "tf.opB" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER15]]" // CHECK: "tf.opC" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER15]]" + // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER14]]" + // CHECK-SAME: _xla_outside_compilation = "{{[a-zA-Z_0-9]+}}" "tf_device.cluster"() ( { %a = "tf.opA"() {_xla_outside_compilation = "0"} : () -> tensor - %b = "tf.opB"() {_xla_outside_compilation = "0"} : () -> tensor + %b = "tf.opB"(%a) {_xla_outside_compilation = "0"} : (tensor) -> tensor "tf_device.cluster" () ( { tf_device.return }) {cluster_attr = "cluster_attr"} : () -> () - %c = "tf.opC"() {_xla_outside_compilation = "0"} : () -> tensor + %c = "tf.opC"(%b) {_xla_outside_compilation = "0"} : (tensor) -> tensor tf_device.return }) {cluster_attr = "cluster_attr"} : () -> () return @@ -248,3 +270,144 @@ func @two_clusters_with_one_op_each_with_indirect_dependency() { }) {cluster_attr = "cluster_attr"} : () -> () return } + +// CHECK-LABEL: func @check_ops_with_data_dependency_added_as_host_cluster +func @check_ops_with_data_dependency_added_as_host_cluster() { + // CHECK: "tf.opA" + // CHECK-NEXT: "tf.opB" + // CHECK-SAME: 
_xla_outside_compilation = "[[CLUSTER16:[a-zA-Z_0-9]+]]" + // CHECK-NEXT: "tf.Identity" + // CHECK-NEXT: "tf.Identity" + // CHECK-NEXT: "tf.opE" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER16]]" + // CHECK-NEXT: "tf.opF" + "tf_device.cluster"() ( { + %a = "tf.opA"() : () -> tensor + %b = "tf.opB"(%a) {_xla_outside_compilation = "0"} : (tensor) -> tensor + %c = "tf.Identity"(%b) : (tensor) -> tensor + %d = "tf.Identity"(%c) : (tensor) -> tensor + %e = "tf.opE"(%d, %b, %c) {_xla_outside_compilation = "0"} : (tensor, tensor, tensor) -> tensor + "tf.opF"(%e) : (tensor) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @check_op_inside_nested_region_clustered +func @check_op_inside_nested_region_clustered(%arg0 : tensor<*x!tf.resource>) { + // CHECK: tf_device.cluster + // CHECK: "tf.IfRegion" + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: "tf.Const" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17:[a-zA-Z_0-9]+]]" + // CHECK-NEXT: "tf.Const" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" + // CHECK-NEXT: "tf.WriteSummary" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.IfRegion"(%0) ( { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.B"() : () -> (tensor) + %3 = "tf.C"() : () -> (tensor) + %4 = "tf.Const"() {_xla_outside_compilation = "auto0", value = dense<"logits"> : tensor} : () -> tensor + %5 = "tf.Const"() {_xla_outside_compilation = "auto1", value = dense<"\0A\09\0A\07scalars"> : tensor} : () -> tensor + "tf.WriteSummary"(%arg0, %2, %3, %4, %5) {_xla_outside_compilation = "auto2", device = "/device:CPU:0"} : (tensor<*x!tf.resource>, tensor, tensor, tensor, tensor) -> () + "tf.Yield"(%1) : (tensor) -> () + }, { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.Yield"(%1) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> tensor + + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @check_ops_inside_different_block_in_different_cluster +func @check_ops_inside_different_block_in_different_cluster(%arg0 : tensor<*x!tf.resource>) { + // CHECK: tf_device.cluster + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.B" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17:[a-zA-Z_0-9]+]]" + // CHECK-NEXT: "tf.C" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17:[a-zA-Z_0-9]+]]" + // CHECK: "tf.IfRegion" + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.Const" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18:[a-zA-Z_0-9]+]]" + // CHECK-NEXT: "tf.Const" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18]]" + // CHECK-NEXT: "tf.WriteSummary" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18]]" + // CHECK: "tf.Const" + // CHECK-NEXT: "tf.Const" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER19:[a-zA-Z_0-9]+]]" + // CHECK-NEXT: "tf.D" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER19]]" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.B"() {_xla_outside_compilation = "auto1"} : () -> (tensor) + %3 = "tf.C"() {_xla_outside_compilation = "auto2"} : () -> (tensor) + "tf.IfRegion"(%0) ( { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %4 = "tf.Const"() {_xla_outside_compilation = "auto3", value = dense<"logits"> : tensor} : () -> tensor + %5 = "tf.Const"() 
{_xla_outside_compilation = "auto4", value = dense<"\0A\09\0A\07scalars"> : tensor} : () -> tensor + "tf.WriteSummary"(%arg0, %2, %3, %4, %5) {_xla_outside_compilation = "auto2", device = "/device:CPU:0"} : (tensor<*x!tf.resource>, tensor, tensor, tensor, tensor) -> () + "tf.Yield"(%1) : (tensor) -> () + }, { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %4 = "tf.Const"() {_xla_outside_compilation = "auto5", value = dense<"a"> : tensor} : () -> tensor + "tf.D"(%3, %4, %1) {_xla_outside_compilation = "auto6"} : (tensor, tensor, tensor) -> () + "tf.Yield"(%1) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> tensor + + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @check_clustering_ops_inside_nested_control_flow +func @check_clustering_ops_inside_nested_control_flow(%arg0 : tensor<*x!tf.resource>) { + // CHECK: tf_device.cluster + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.B" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17:[a-zA-Z_0-9]+]]" + // CHECK-NEXT: "tf.C" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17:[a-zA-Z_0-9]+]]" + // CHECK: "tf.IfRegion" + // CHECK: "tf.IfRegion" + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.Const" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18:[a-zA-Z_0-9]+]]" + // CHECK-NEXT: "tf.Const" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18]]" + // CHECK-NEXT: "tf.WriteSummary" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18]]" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.B"() {_xla_outside_compilation = "auto1"} : () -> (tensor) + %3 = "tf.C"() {_xla_outside_compilation = "auto2"} : () -> (tensor) + "tf.IfRegion"(%0) ( { + %6 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.IfRegion"(%6) ( { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %4 = "tf.Const"() {_xla_outside_compilation = "auto3", value = dense<"logits"> : tensor} : () -> tensor + %5 = "tf.Const"() {_xla_outside_compilation = "auto4", value = dense<"\0A\09\0A\07scalars"> : tensor} : () -> tensor + "tf.WriteSummary"(%arg0, %2, %3, %4, %5) {_xla_outside_compilation = "auto2", device = "/device:CPU:0"} : (tensor<*x!tf.resource>, tensor, tensor, tensor, tensor) -> () + "tf.Yield"(%1) : (tensor) -> () + }, { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.Yield"(%1) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> tensor + "tf.Yield"(%6) : (tensor) -> () + }, { + %7 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.Yield"(%7) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> tensor + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_parallel_execute_sink_resource_write.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_parallel_execute_sink_resource_write.mlir new file mode 100644 index 00000000000..ad4433c1d20 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_parallel_execute_sink_resource_write.mlir @@ -0,0 +1,137 @@ +// RUN: tf-opt %s -tf-tpu-parallel-execute-sink-resource-write | FILECHECK_OPTS="" FileCheck %s + +// CHECK-LABEL: func @multiple_uses +// CHECK-SAME: ({{.+}}: tensor, [[ARG1:%.+]]: tensor) +func @multiple_uses(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[PARALLEL_EXECUTE:%.+]]:2 = "tf_device.parallel_execute" + %0:2 = "tf_device.parallel_execute"() ( { + tf_device.return %arg0 : tensor + }, { + tf_device.return %arg0 : 
tensor + // CHECK: }) : () -> (tensor, tensor) + }) : () -> (tensor, tensor) + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG1]], [[PARALLEL_EXECUTE]]#0) + "tf.AssignVariableOp"(%arg1, %0#0) : (tensor, tensor) -> () + // CHECK-NEXT: return [[PARALLEL_EXECUTE]]#0 + return %0#0 : tensor +} + +// CHECK-LABEL: func @not_assign_var +// CHECK-SAME: ({{.+}}: tensor, [[ARG1:%.+]]: tensor) +func @not_assign_var(%arg0: tensor, %arg1: tensor) { + // CHECK: [[PARALLEL_EXECUTE:%.+]]:2 = "tf_device.parallel_execute" + %0:2 = "tf_device.parallel_execute"() ( { + tf_device.return %arg0 : tensor + }, { + tf_device.return %arg0 : tensor + // CHECK: }) : () -> (tensor, tensor) + }) : () -> (tensor, tensor) + // CHECK-NEXT: "tf.AssignAddVariableOp"([[ARG1]], [[PARALLEL_EXECUTE]]#0) + "tf.AssignAddVariableOp"(%arg1, %0#0) : (tensor, tensor) -> () + return +} + +// CHECK-LABEL: func @resource_handle_output +// CHECK-SAME: ([[ARG0:%.+]]: tensor, {{.+}}: tensor) +func @resource_handle_output(%arg0: tensor, %arg1: tensor) { + // CHECK: [[PARALLEL_EXECUTE:%.+]]:2 = "tf_device.parallel_execute" + %0:2 = "tf_device.parallel_execute"() ( { + tf_device.return %arg1 : tensor + }, { + tf_device.return %arg1 : tensor + // CHECK: }) : () -> (tensor, tensor) + }) : () -> (tensor, tensor) + // CHECK-NEXT: "tf.AssignVariableOp"([[PARALLEL_EXECUTE]]#0, [[ARG0]]) + "tf.AssignVariableOp"(%0#0, %arg0) : (tensor, tensor) -> () + return +} + +// CHECK-LABEL: func @resource_handle_and_value_output +func @resource_handle_and_value_output(%arg0: tensor, %arg1: tensor) { + // CHECK: [[PARALLEL_EXECUTE:%.+]]:2 = "tf_device.parallel_execute" + %0:2 = "tf_device.parallel_execute"() ( { + tf_device.return %arg0, %arg1 : tensor, tensor + }, { + tf_device.return + }) : () -> (tensor, tensor) + // CHECK: "tf.AssignVariableOp"([[PARALLEL_EXECUTE]]#1, [[PARALLEL_EXECUTE]]#0) + "tf.AssignVariableOp"(%0#1, %0#0) : (tensor, tensor) -> () + return +} + +// CHECK-LABEL: func @resource_handle_after_parallel_execute +func @resource_handle_after_parallel_execute(%arg0: tensor) { + // CHECK: [[PARALLEL_EXECUTE:%.+]]:2 = "tf_device.parallel_execute" + %0:2 = "tf_device.parallel_execute"() ( { + tf_device.return %arg0 : tensor + }, { + tf_device.return %arg0 : tensor + // CHECK: }) : () -> (tensor, tensor) + }) : () -> (tensor, tensor) + // CHECK-NEXT: [[VAR:%.+]] = "tf.VarHandleOp" + %1 = "tf.VarHandleOp"() {container = "", shape = #tf.shape<>, shared_name = "x"} : () -> tensor>> + // CHECK-NEXT: "tf.AssignVariableOp"([[VAR]], [[PARALLEL_EXECUTE]]#0) + "tf.AssignVariableOp"(%1, %0#0) : (tensor>>, tensor) -> () + return +} + +// CHECK-LABEL: func @replace_single_output +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor, [[ARG3:%.+]]: tensor) +func @replace_single_output(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) { + // CHECK: {{%.+}}:2 = "tf_device.parallel_execute" + %0:3 = "tf_device.parallel_execute"() ( { + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG3]], [[ARG1]]) + // CHECK-NEXT: tf_device.return [[ARG0]], [[ARG2]] : tensor, tensor + tf_device.return %arg0, %arg1, %arg2 : tensor, tensor, tensor + // CHECK-NEXT: }, { + }, { + // CHECK-NEXT: tf_device.return + tf_device.return + // CHECK-NEXT: }) : () -> (tensor, tensor) + }) : () -> (tensor, tensor, tensor) + "tf.AssignVariableOp"(%arg3, %0#1) : (tensor, tensor) -> () + // CHECK-NEXT: return + return +} + +// CHECK-LABEL: func @replace_multiple_outputs +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor, [[ARG3:%.+]]: tensor, 
[[ARG4:%.+]]: tensor, [[ARG5:%.+]]: tensor, [[ARG6:%.+]]: tensor) +func @replace_multiple_outputs(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor) { + // CHECK: {{%.+}}:3 = "tf_device.parallel_execute" + %0:5 = "tf_device.parallel_execute"() ( { + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG5]], [[ARG1]]) + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG6]], [[ARG3]]) + // CHECK-NEXT: tf_device.return [[ARG0]], [[ARG2]], [[ARG4]] : tensor, tensor, tensor + tf_device.return %arg0, %arg1, %arg2, %arg3, %arg4 : tensor, tensor, tensor, tensor, tensor + // CHECK-NEXT: }, { + }, { + // CHECK-NEXT: tf_device.return + tf_device.return + // CHECK-NEXT: }) : () -> (tensor, tensor, tensor) + }) : () -> (tensor, tensor, tensor, tensor, tensor) + "tf.AssignVariableOp"(%arg5, %0#1) : (tensor, tensor) -> () + "tf.AssignVariableOp"(%arg6, %0#3) : (tensor, tensor) -> () + // CHECK-NEXT: return + return +} + +// CHECK-LABEL: func @replace_multiple_outputs_regions +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor, [[ARG3:%.+]]: tensor, [[ARG4:%.+]]: tensor, [[ARG5:%.+]]: tensor, [[ARG6:%.+]]: tensor, [[ARG7:%.+]]: tensor) +func @replace_multiple_outputs_regions(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor) { + // CHECK: {{%.+}}:4 = "tf_device.parallel_execute" + %0:6 = "tf_device.parallel_execute"() ( { + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG6]], [[ARG1]]) + // CHECK-NEXT: tf_device.return [[ARG0]], [[ARG2]] : tensor, tensor + tf_device.return %arg0, %arg1, %arg2 : tensor, tensor, tensor + // CHECK-NEXT: }, { + }, { + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG7]], [[ARG4]]) + // CHECK-NEXT: tf_device.return [[ARG3]], [[ARG5]] : tensor, tensor + tf_device.return %arg3, %arg4, %arg5 : tensor, tensor, tensor + // CHECK-NEXT: }) : () -> (tensor, tensor, tensor, tensor) + }) : () -> (tensor, tensor, tensor, tensor, tensor, tensor) + "tf.AssignVariableOp"(%arg6, %0#1) : (tensor, tensor) -> () + "tf.AssignVariableOp"(%arg7, %0#4) : (tensor, tensor) -> () + // CHECK-NEXT: return + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 2a0091ce9bf..ef7b52cd978 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -1262,15 +1262,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NOT:"tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: "tf.E"(%[[COMPILE_OUTPUT]]#1 %3 = "tf_device.parallel_execute"() ( { - %program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor - "tf.D"(%program) : (tensor) -> () + %program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<2x!tf.string> + "tf.D"(%program) : (tensor<2x!tf.string>) -> () tf_device.return }, { %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor tf_device.return %4 : tensor }, { - %program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor - "tf.E"(%program) : (tensor) -> () + %program = 
"tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<2x!tf.string> + "tf.E"(%program) : (tensor<2x!tf.string>) -> () tf_device.return }) : () -> (tensor) tf_device.return %3 : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir index 280986a7ee1..ceecb3e72d9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir @@ -83,5 +83,80 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0" } } -// ---- +// ----- + +// Tests for space to depth host and device transform with replicate inputs. + +module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:COMPOSITE:0" = {}, "/job:localhost/replica:0/task:0/device:CPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:1" = {}, "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0" = {}}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 458 : i32}} { + func @main(%arg0: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg1: tensor {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg3: tensor {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg5: tensor {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf.resource>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg7: tensor<*x!tf.resource>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg8: tensor<*x!tf.resource>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg9: tensor<*x!tf.resource>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg10: tensor<*x!tf.resource>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg11: tensor<*x!tf.resource>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg12: tensor<*x!tf.resource>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}) attributes {tf.entry_function = {control_outputs = "IteratorGetNext,IteratorGetNext_1,CrossReplicaSum,AssignAddVariableOp,CrossReplicaSum_1,AssignAddVariableOp_1,CrossReplicaSum_2,AssignAddVariableOp_2,CrossReplicaSum_3,AssignAddVariableOp_3", inputs = 
"iterator,iterator_1,iterator_2,iterator_3,iterator_4,iterator_5,resnet50_conv1_conv2d_conv1_kernel_140365606309224_handle_inputs_0,resnet50_fc1000_matmul_fc1000_kernel_140365944145960_handle_inputs_0,resnet50_fc1000_biasadd_fc1000_bias_140365944146240_handle_inputs_0,total_140366323758976_handle_inputs_0,count_140366323759312_handle_inputs_0,total_140366323760264_handle_inputs_0,count_140366323760600_handle_inputs_0", outputs = ""}} { + // CHECK: %[[INPUT00:.*]] = "tf.IteratorGetNext" + // CHECK-DAG: %[[SPACETODEPTH00:.*]] = "tf.SpaceToDepth"([[INPUT00:.*]]#0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32> + %0:2 = "tf.IteratorGetNext"(%arg2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<2x224x224x3xf32>, tensor<2x1xf32>) + // CHECK: %[[INPUT01:.*]] = "tf.IteratorGetNext" + // CHECK-DAG: %[[SPACETODEPTH01:.*]] = "tf.SpaceToDepth"([[INPUT01:.*]]#0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32> + %1:2 = "tf.IteratorGetNext"(%arg4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<2x224x224x3xf32>, tensor<2x1xf32>) + tf_device.replicate([%0#0, %1#0] as %arg13: tensor<2x224x224x3xf32>, [%0#1, %1#1] as %arg14: tensor<2x1xf32>, %arg6 as %arg15: tensor<*x!tf.resource>>, %arg8 as %arg16: tensor<*x!tf.resource>>, %arg7 as %arg17: tensor<*x!tf.resource>>, %arg9 as %arg18: tensor<*x!tf.resource>>, %arg10 as %arg19: tensor<*x!tf.resource>>, %arg11 as %arg20: tensor<*x!tf.resource>>, %arg12 as %arg21: tensor<*x!tf.resource>>) {_mirrored_variable_indices = [2, 3, 4, 5, 6, 7, 8], _replicated_input_indices = [1, 2, -1, -1, -1, -1, -1, -1, -1], devices = {}, n = 2 : i32} { + %2 = "tf.ReadVariableOp"(%arg15) : (tensor<*x!tf.resource>>) -> tensor<7x7x3x64xf32> + %3 = "tf.ReadVariableOp"(%arg16) : (tensor<*x!tf.resource>>) -> tensor<1001xf32> + %4 = "tf.ReadVariableOp"(%arg17) : (tensor<*x!tf.resource>>) -> tensor<64x1001xf32> + %5 = "tf.ReadVariableOp"(%arg18) : (tensor<*x!tf.resource>>) -> tensor + %6 = "tf.ReadVariableOp"(%arg19) : (tensor<*x!tf.resource>>) -> tensor + %7 = "tf.ReadVariableOp"(%arg20) : (tensor<*x!tf.resource>>) -> tensor + %8 = "tf.ReadVariableOp"(%arg21) : (tensor<*x!tf.resource>>) -> tensor + %9:4 = "tf_device.cluster_func"(%arg13, %arg14, %2, %4, %3, %5, %6, %7, %8) {_tpu_replicate = "cluster_eval_step", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], func = @_func, host_compute_core = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], num_cores_per_replica = 1 : i64, output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : (tensor<2x224x224x3xf32>, tensor<2x1xf32>, tensor<7x7x3x64xf32>, tensor<64x1001xf32>, tensor<1001xf32>, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor) + "tf.AssignVariableOp"(%arg18, %9#0) : (tensor<*x!tf.resource>>, tensor) -> () + "tf.AssignVariableOp"(%arg19, %9#1) : (tensor<*x!tf.resource>>, tensor) -> () + "tf.AssignVariableOp"(%arg20, %9#2) : 
(tensor<*x!tf.resource>>, tensor) -> () + "tf.AssignVariableOp"(%arg21, %9#3) : (tensor<*x!tf.resource>>, tensor) -> () + tf_device.return + } + return + } + // CHECK-LABEL: func @_func + // CHECK-SAME: [[FUNCINPUT00:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} { + func @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} { + %0 = "tf.Const"() {value = dense<2.000000e+00> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<[[0, 1]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32> + %4 = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %5 = "tf.Const"() {value = dense<2.500000e-01> : tensor} : () -> tensor + %6 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %7 = "tf.Const"() {value = dense<[-1, 1001]> : tensor<2xi32>} : () -> tensor<2xi32> + %8 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + %9 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %10 = "tf.Const"() {value = dense<[[0, 0], [3, 3], [3, 3], [0, 0]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32> + %11 = "tf.Pad"(%arg0, %10) : (tensor<2x224x224x3xf32>, tensor<4x2xi32>) -> tensor<2x230x230x3xf32> + %12 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2x1xf32>) -> tensor<2x1xi64> + %13 = "tf.Reshape"(%12, %9) : 
(tensor<2x1xi64>, tensor<1xi32>) -> tensor<2xi64> + %14 = "tf.Squeeze"(%arg1) {squeeze_dims = [-1]} : (tensor<2x1xf32>) -> tensor<2xf32> + // CHECK: "tf.Conv2D" + // CHECK-SAME: strides = [1, 1, 1, 1] + // CHECK-SAME: (tensor<2x115x115x12xf32>, tensor<4x4x12x64xf32>) -> tensor<2x112x112x64xf32> + %15 = "tf.Conv2D"(%11, %arg2) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) -> tensor<2x112x112x64xf32> + %16 = "tf.Mean"(%15, %8) {keep_dims = false} : (tensor<2x112x112x64xf32>, tensor<2xi32>) -> tensor<2x64xf32> + %17 = "tf.MatMul"(%16, %arg3) {transpose_a = false, transpose_b = false} : (tensor<2x64xf32>, tensor<64x1001xf32>) -> tensor<2x1001xf32> + %18 = "tf.BiasAdd"(%17, %arg4) {data_format = "NHWC"} : (tensor<2x1001xf32>, tensor<1001xf32>) -> tensor<2x1001xf32> + %19 = "tf.Reshape"(%18, %7) : (tensor<2x1001xf32>, tensor<2xi32>) -> tensor<2x1001xf32> + %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%19, %13) : (tensor<2x1001xf32>, tensor<2xi64>) -> (tensor<2xf32>, tensor<2x1001xf32>) + %20 = "tf.Sum"(%loss, %6) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor + %21 = "tf.Mul"(%20, %5) : (tensor, tensor) -> tensor + %22 = "tf.Sum"(%21, %4) {keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %23 = "tf.CrossReplicaSum"(%22, %3) : (tensor, tensor<1x2xi32>) -> tensor + %24 = "tf.Softmax"(%18) : (tensor<2x1001xf32>) -> tensor<2x1001xf32> + %25 = "tf.ArgMax"(%24, %2) : (tensor<2x1001xf32>, tensor) -> tensor<2xi64> + %26 = "tf.Cast"(%25) {Truncate = false} : (tensor<2xi64>) -> tensor<2xf32> + %27 = "tf.Equal"(%14, %26) {incompatible_shape_error = true} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + %28 = "tf.Cast"(%27) {Truncate = false} : (tensor<2xi1>) -> tensor<2xf32> + %29 = "tf.Sum"(%28, %6) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor + %30 = "tf.CrossReplicaSum"(%29, %3) : (tensor, tensor<1x2xi32>) -> tensor + %31 = "tf.AddV2"(%arg5, %23) : (tensor, tensor) -> tensor + %32 = "tf.CrossReplicaSum"(%1, %3) : (tensor, tensor<1x2xi32>) -> tensor + %33 = "tf.AddV2"(%arg6, %32) : (tensor, tensor) -> tensor + %34 = "tf.AddV2"(%arg7, %30) : (tensor, tensor) -> tensor + %35 = "tf.CrossReplicaSum"(%0, %3) : (tensor, tensor<1x2xi32>) -> tensor + %36 = "tf.AddV2"(%arg8, %35) : (tensor, tensor) -> tensor + return %31, %33, %34, %36 : tensor, tensor, tensor, tensor + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/visitor-interrupt-util.mlir b/tensorflow/compiler/mlir/tensorflow/tests/visitor-interrupt-util.mlir new file mode 100644 index 00000000000..8cc8d273bec --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/visitor-interrupt-util.mlir @@ -0,0 +1,91 @@ +// RUN: tf-opt -split-input-file -verify-diagnostics -tf-test-visitor-util-interrupt %s + +// Test simple operations with no regions and no interrupts. They should be +// visited with stage "before all regions". + +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{4: after all regions}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{3: before all regions}} + return %0 : tensor +} + +// ----- + +// Test simple operations with no regions and interrupts. 
No remarks after +// the interrupting operation is visited. + +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{2: walk was interrupted}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + %0 = "tf.Identity"(%arg0) {interrupt_before_all = true} : (tensor) -> tensor + return %0 : tensor +} + +// ----- +// Test operation with non empty regions. +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{5: walk was interrupted}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + %0 = "tf.unknownop"(%arg0) ({ + // expected-remark@below {{3: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{4: before all regions}} + "tf.yield"(%1) : (tensor) -> () + }) {interrupt_after_all = true} : (tensor) -> tensor + return %0 : tensor +} + +// ----- +// Test operation with multiple regions. +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{5: walk was interrupted}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + %0 = "tf.unknownop"(%arg0) ({ + // expected-remark@below {{3: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{4: before all regions}} + "tf.yield"(%1) : (tensor) -> () + }, { + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + "tf.yield"(%1) : (tensor) -> () + }) {interrupt_after_region = 0} : (tensor) -> tensor + return %0 : tensor +} + +// ----- +// Test static filtering +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{7: walk was interrupted}} +func @foo(%arg0: tensor, %arg1: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + // expected-remark@below {{5: before region #1}} + // expected-remark@below {{8: before all regions}} + // expected-remark@below {{9: before region #1}} + // expected-remark@below {{10: after all regions}} + %0 = "tf.IfRegion"(%arg1) ({ + // expected-remark@below {{3: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{4: before all regions}} + "tf.Yield"(%1) : (tensor) -> () + }, { + // expected-remark@below {{6: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + "tf.Yield"(%1) { interrupt_after_all = true } : (tensor) -> () + }) {is_stateless = true}: (tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir b/tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir new file mode 100644 index 00000000000..9a832b7fe8d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir @@ -0,0 +1,102 @@ +// RUN: tf-opt -split-input-file -verify-diagnostics -tf-test-visitor-util %s + +// Test simple operations with no regions. They should be visited with stage +// = before all regions. 
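// (Illustrative note, not part of the test input: the interrupt behaviour exercised by the
// -tf-test-visitor-util-interrupt cases above maps onto MLIR's interruptible walk idiom,
// sketched below in C++. The attribute name comes from the tests themselves, but
// WalkUntilInterrupted is a hypothetical helper standing in for the visitor utility under test.)
#include "mlir/IR/Operation.h"
#include "mlir/IR/Visitors.h"

static bool WalkUntilInterrupted(mlir::Operation *root) {
  mlir::WalkResult result = root->walk([](mlir::Operation *op) {
    // Stop visiting as soon as an op carries the interrupt marker; ops after it
    // are never reached, which is why no further remarks are emitted.
    if (op->getAttr("interrupt_before_all")) return mlir::WalkResult::interrupt();
    return mlir::WalkResult::advance();
  });
  return result.wasInterrupted();
}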
+ +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{4: after all regions}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{3: before all regions}} + return %0 : tensor +} + +// ----- +// Test operation with empty regions. +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{5: after all regions}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + // expected-remark@below {{3: after all regions}} + %0 = "tf.unknownop"(%arg0) ({ + }) : (tensor) -> tensor + // expected-remark@below {{4: before all regions}} + return %0 : tensor +} + +// ----- +// Test operation with non empty regions. +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{7: after all regions}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + // expected-remark@below {{5: after all regions}} + %0 = "tf.unknownop"(%arg0) ({ + // expected-remark@below {{3: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{4: before all regions}} + "tf.yield"(%1) : (tensor) -> () + }) : (tensor) -> tensor + // expected-remark@below {{6: before all regions}} + return %0 : tensor +} + +// ----- +// Test operation with multiple regions. +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{10: after all regions}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + // expected-remark@below {{5: before region #1}} + // expected-remark@below {{8: after all regions}} + %0 = "tf.unknownop"(%arg0) ({ + // expected-remark@below {{3: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{4: before all regions}} + "tf.yield"(%1) : (tensor) -> () + }, { + // expected-remark@below {{6: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{7: before all regions}} + "tf.yield"(%1) : (tensor) -> () + }) : (tensor) -> tensor + // expected-remark@below {{9: before all regions}} + return %0 : tensor +} + +// ----- +// Test static filtering +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{10: after all regions}} +func @foo(%arg0: tensor, %arg1: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + // expected-remark@below {{5: before region #1}} + // expected-remark@below {{8: after all regions}} + // expected-remark@below {{11: before all regions}} + // expected-remark@below {{12: before region #1}} + // expected-remark@below {{13: after all regions}} + %0 = "tf.IfRegion"(%arg1) ({ + // expected-remark@below {{3: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{4: before all regions}} + "tf.Yield"(%1) : (tensor) -> () + }, { + // expected-remark@below {{6: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // 
expected-remark@below {{7: before all regions}} + "tf.Yield"(%1) : (tensor) -> () + }) {is_stateless = true}: (tensor) -> tensor + // expected-remark@below {{9: before all regions}} + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc index de73dff8b0b..fe0c5bea44e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h" - #include #include #include @@ -41,7 +39,48 @@ namespace mlir { namespace TF { namespace { -// Replace TF BatchMatMul by TF Einsum + +// Replace TF BatchMatMul by TF Einsum op +template +class ConvertTFBatchMatMulToEinsumOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BatchMatMulOpType op, + PatternRewriter& rewriter) const override { + Value input_lhs = op.x(); + Value input_rhs = op.y(); + + // LHS and RHS must be a ranked tensor type + auto lhs_type = input_lhs.getType().dyn_cast(); + auto rhs_type = input_rhs.getType().dyn_cast(); + + if (!lhs_type || !rhs_type) return failure(); + + auto lhs_shape = lhs_type.getShape(); + auto rhs_shape = rhs_type.getShape(); + + // Ensure that input ranks are at least 2. + const int dims_a = lhs_shape.size(); + const int dims_b = rhs_shape.size(); + if (dims_a < 2 || dims_b < 2) { + return failure(); + } + + // einsum equation for batchmatmul + std::string equation("...mk,...kn->...mn"); + if (op.adj_x()) std::swap(equation[3], equation[4]); + if (op.adj_y()) std::swap(equation[6 + 3], equation[6 + 4]); + + rewriter.replaceOpWithNewOp( + op, op.getType(), + /*inputs=*/ValueRange({input_lhs, input_rhs}), + /*equation=*/equation); + + return success(); + } +}; + struct BatchMatMulToEinsumPass : public PassWrapper { void runOnFunction() override; @@ -57,65 +96,10 @@ void BatchMatMulToEinsumPass::runOnFunction() { applyPatternsAndFoldGreedily(func, patterns); } -} // namespace - -template -LogicalResult -ConvertTFBatchMatMulToEinsumOp::matchAndRewrite( - BatchMatMulOpType op, PatternRewriter& rewriter) const { - Value input_lhs = op.x(); - Value input_rhs = op.y(); - - if (!input_lhs.getType().isa()) { - // LHS must be a ranked tensor type - return failure(); - } - if (!input_rhs.getType().isa()) { - // RHS must be a ranked tensor type - return failure(); - } - - auto lhs_type = input_lhs.getType().dyn_cast(); - auto rhs_type = input_rhs.getType().dyn_cast(); - - if (!lhs_type || !rhs_type) { - return failure(); - } - - auto lhs_shape = lhs_type.getShape(); - auto rhs_shape = rhs_type.getShape(); - - Location loc = op.getLoc(); - - // Ensure that input ranks are at least 2. 
- const int dims_a = lhs_shape.size(); - const int dims_b = rhs_shape.size(); - if (dims_a < 2 || dims_b < 2) { - // Both inputs must have rank >= 2 - return failure(); - } - - // einsum equation for batchmatmul - std::string equation("...mk,...kn->...mn"); - - if (op.adj_x()) { - std::swap(equation[3], equation[4]); - } - if (op.adj_y()) { - std::swap(equation[6 + 3], equation[6 + 4]); - } - - llvm::SmallVector inputs = {input_lhs, input_rhs}; - rewriter.replaceOpWithNewOp(op, op.getType(), - /*inputs=*/ValueRange(inputs), - /*equation=*/equation); - - return success(); -} - -static PassRegistration pass( +PassRegistration pass( "tf-batch-matmul-to-tf-einsum", "Replace TF BatchMatMul op by TF Einsum op."); +} // namespace std::unique_ptr> CreateBatchMatMulToEinsumPass() { return std::make_unique(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h deleted file mode 100644 index d39f3575b4a..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BATCHMATMUL_TO_EINSUM_H_ -#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BATCHMATMUL_TO_EINSUM_H_ - -#include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/core/util/matmul_bcast.h" - -namespace mlir { -namespace TF { - -// Replace TF BatchMatMul by TF Einsum op -template -class ConvertTFBatchMatMulToEinsumOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite( - BatchMatMulOpType op, - PatternRewriter& rewriter) const override; // NOLINT -}; - -} // namespace TF -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BATCHMATMUL_TO_EINSUM_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index ed0528ae054..358963a79e1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -47,7 +47,8 @@ void AddGraphExportLoweringPasses(OpPassManager &pm) { pm.addNestedPass(CreateFunctionalToExecutorDialectConversionPass()); add_pass(TFDevice::CreateParallelizeEmbeddingParamsOpsPass()); - add_pass(TFDevice::CreateReplicateToIslandPass()); + pm.addPass(TFDevice::CreateReplicateToIslandPass()); + pm.addPass(CreateBreakUpIslandsPass()); add_pass(TFDevice::CreateParallelExecuteToIslandsPass()); add_pass(TFDevice::CreateLaunchToDeviceAttributePass()); } @@ -85,8 +86,8 @@ void 
CreateTPUBridgePipeline(OpPassManager &pm) { // Encode this in its own scope so that func_pm is not mistakenly used // later on. { + pm.addPass(CreateTPUClusterFormationPass()); OpPassManager &func_pm = pm.nest(); - func_pm.addPass(CreateTPUClusterFormationPass()); // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass // because DecomposeResourceOpsPass uses pattern rewriter which hoists // changed constants out of tf_device.Launch. @@ -94,26 +95,32 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { func_pm.addPass(CreateTPUHostComputationExpansionPass()); func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass()); } - pm.addPass(TF::CreateTFFunctionalControlFlowToRegions()); - pm.addPass(mlir::createInlinerPass()); - pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass()); - pm.addPass(TF::CreateTFRegionControlFlowToFunctional()); - // Run another shape inference pass because resource decomposition might have // created new partial types. pm.addPass(TF::CreateTFShapeInferencePass()); - pm.addNestedPass(tf_executor::CreateTFExecutorConstantSinkingPass()); + pm.addPass(TF::CreateTFFunctionalControlFlowToRegions()); + pm.addPass(mlir::createInlinerPass()); + pm.addPass(CreateTPUClusterCleanupAttributesPass()); pm.addPass(TFDevice::CreateResourceOpLiftingPass()); + pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass()); + pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass()); + pm.addPass(CreateTPUExtractOutsideCompilationPass()); + + pm.addNestedPass(tf_executor::CreateTFExecutorConstantSinkingPass()); pm.addPass(TF::CreateResourceDeviceInferencePass()); pm.addPass(TFDevice::CreateClusterOutliningPass()); pm.addPass(CreateTPUDynamicPaddingMapperPass()); + pm.addPass(CreateTPUResourceReadForWritePass()); pm.addPass(CreateTPUShardingIdentificationPass()); pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass()); pm.addPass(CreateTPURewritePass()); pm.addPass(createSymbolDCEPass()); pm.addNestedPass(TFDevice::CreateReplicateInvariantOpHoistingPass()); pm.addNestedPass(CreateTPUDynamicLayoutPass()); + pm.addNestedPass(CreateTPUParallelExecuteSinkResourceWritePass()); pm.addNestedPass(CreateTPUMergeVariablesWithExecutePass()); + pm.addNestedPass(CreateTPUColocateCompositeResourceOps()); + pm.addPass(TF::CreateTFRegionControlFlowToFunctional()); pm.addPass(CreateTPUVariableReformattingPass()); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc index 2b8ab85be38..e85058a1964 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc @@ -39,6 +39,10 @@ namespace { struct ClusterFormationPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override; }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc index 57a5cd888a1..cde07503e75 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc @@ -181,14 +181,14 @@ llvm::Optional GetElementTypeFromAccess( llvm::function_ref(Operation*)> infer_from_op) { for (auto& use : collection.getUses()) { if (auto while_op = llvm::dyn_cast(use.getOwner())) { - auto body = while_op.body_func(); + auto body = 
while_op.body_function(); assert(body); auto type_from_body = GetElementTypeFromAccess( body.getArgument(use.getOperandNumber()), module, infer_from_op); if (type_from_body.hasValue()) return type_from_body; } else if (auto if_op = llvm::dyn_cast(use.getOwner())) { - auto then_branch = if_op.then_func(); - auto else_branch = if_op.else_func(); + auto then_branch = if_op.then_function(); + auto else_branch = if_op.else_function(); assert(then_branch && else_branch); auto type_from_then = GetElementTypeFromAccess( then_branch.getArgument(use.getOperandNumber() - 1), module, diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 1429e2b3fd4..3005c78c54f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/c/eager/c_api.h" @@ -68,7 +69,7 @@ static bool ShouldBeFolded(Operation* inst) { LogicalResult ConstantFoldFallbackHook( Operation* inst, ArrayRef operands, - SmallVectorImpl& results) { // NOLINT + SmallVectorImpl& results) { // NOLINT // Instructions with side effects should not be constant folded to preserve // the original semantics. if (inst->getNumRegions() != 0 || !MemoryEffectOpInterface::hasNoEffect(inst)) @@ -126,8 +127,16 @@ LogicalResult ConstantFoldFallbackHook( // TODO(jpienaar): Avoid using global context & mutex here. static auto* mu = new tensorflow::mutex(); tensorflow::mutex_lock l(*mu); - return tensorflow::EvaluateOperation(inst, inputs, ctx, &results); + SmallVector constants; + LogicalResult status = + tensorflow::EvaluateOperation(inst, inputs, ctx, &constants); + results.assign(constants.begin(), constants.end()); + return status; } +static bool init_hooks = ([] () { + TensorFlowDialect::RegisterConstantFoldHook(ConstantFoldFallbackHook); +}(), true); + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h index 69e39080965..887eea745e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h @@ -27,7 +27,7 @@ namespace TF { LogicalResult ConstantFoldFallbackHook( Operation *inst, ArrayRef operands, - SmallVectorImpl &results); // NOLINT + SmallVectorImpl &results); // NOLINT } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/contraction_fusion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/contraction_fusion.cc new file mode 100644 index 00000000000..b5d09f7a794 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/contraction_fusion.cc @@ -0,0 +1,162 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/UseDefLists.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TF { +namespace { + +// -------------------------------------------------------------------------- // +// Fuse ContractionFusableInterface operations into contraction operation. +// -------------------------------------------------------------------------- // + +template +class FuseIntoContractionOp : public RewritePattern { + public: + FuseIntoContractionOp() + : RewritePattern(PatternBenefit(1), MatchAnyOpTypeTag()) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto fusable = dyn_cast(op); + if (!fusable) return failure(); + + auto failed = [&](Twine message) -> LogicalResult { + return rewriter.notifyMatchFailure(op, message); + }; + + // Check if the operation can be fused. + Optional fusion = fusable.GetContractionFusion(); + if (!fusion.hasValue()) { + return failed("returned empty contraction fusion specification"); + } + + // Check if preceding operation is a BaseOp or FusedOp that we can use for + // fusion. + Operation *fuse_into = nullptr; + Value operand = op->getOperand(0); + + if (BaseOp base_op = operand.getDefiningOp()) { + fuse_into = base_op.getOperation(); + } else if (FusedOp fused_op = operand.getDefiningOp()) { + fuse_into = fused_op.getOperation(); + } else { + return failed("input to the fusable op must be a " + + BaseOp::getOperationName() + " or a " + + FusedOp::getOperationName()); + } + + // Operand result must have one use, because we do not want to compute + // tensor contraction twice. + if (!fuse_into->getResult(0).hasOneUse()) { + return failed("fused into op result must have one use"); + } + + MLIRContext *ctx = op->getContext(); + + // Build a fused MatMul operation from a base MatMul and a fusion. + SmallVector locations = {fuse_into->getLoc(), op->getLoc()}; + Location loc = rewriter.getFusedLoc(locations); + + // Fusion can't change the type of a fused operation. + Type result_ty = fuse_into->getResult(0).getType(); + + // Copy all operands from a base op and add additional fusion arguments. + SmallVector operands(fuse_into->getOperands()); + for (int idx : fusion->additional_arguments) { + operands.push_back(op->getOperand(idx)); + } + + // Copy attributes from a base op that we fuse into (e.g. copy all + // MatMul or Conv attributes to the fused operation). + SmallVector attrs(fuse_into->getAttrs().begin(), + fuse_into->getAttrs().end()); + + // Add fusion specific additional attributes. + for (auto attr : fusion->additional_attributes) { + attrs.push_back(attr); + } + + // Add a fused output kernel name to the list of fusions.
+ Identifier fusion_id = Identifier::get("fusion", ctx); + StringAttr fusion_name = StringAttr::get(fusion->output_kernel, ctx); + + auto is_fusion = [&](const NamedAttribute &attr) -> bool { + return attr.first == fusion_id; + }; + + if (isa(fuse_into)) { + NamedAttribute fusion_attr(fusion_id, ArrayAttr::get({fusion_name}, ctx)); + attrs.push_back(fusion_attr); + + } else { + ArrayAttr arr = + llvm::find_if(attrs, is_fusion)->second.template cast(); + llvm::erase_if(attrs, is_fusion); + + auto rng = arr.getAsRange(); + SmallVector updated(rng.begin(), rng.end()); + updated.push_back(fusion_name); + + attrs.push_back(NamedAttribute(fusion_id, ArrayAttr::get(updated, ctx))); + } + + // Update all uses of a fusable op with a new fused operation. + Value fused = rewriter.create(loc, result_ty, operands, attrs); + rewriter.replaceOp(op, {fused}); + + return success(); + } +}; + +// -------------------------------------------------------------------------- // + +using FuseIntoMatMulOp = FuseIntoContractionOp; + +struct ContractionFusionPass + : public PassWrapper { + void runOnFunction() override; +}; + +void ContractionFusionPass::runOnFunction() { + FuncOp func = getFunction(); + + OwningRewritePatternList patterns; + patterns.insert(); + applyPatternsAndFoldGreedily(func, patterns); +} + +} // namespace + +std::unique_ptr> CreateContractionFusionPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-contraction-fusion", + "Fuses operations implementing ContractionFusionInterface into the " + "contraction operations"); + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc similarity index 74% rename from tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc rename to tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc index 109ceea47e7..d309c6d379f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc @@ -19,7 +19,6 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/IR/DialectHooks.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -35,31 +34,22 @@ namespace { // Since this method is passed to MLIR as decode hook it has to conform // to LLVM style used by MLIR. -bool DecodeOpaqueTensorHook(const OpaqueElementsAttr input, - ElementsAttr& output) { // NOLINT +LogicalResult DecodeOpaqueTensorHook(const OpaqueElementsAttr input, + ElementsAttr& output) { // NOLINT Builder builder(input.getType().getContext()); auto decoded_attr_or = tensorflow::DecodeOpaqueTensor(input, builder); if (!decoded_attr_or.ok()) { VLOG(2) << decoded_attr_or.status().error_message(); - return true; + return failure(); } output = decoded_attr_or.ValueOrDie(); - return false; + return success(); } -// Hooks for the TensorFlow dialect.
-class TensorFlowHooks : public DialectHooks { - public: - DialectConstantFoldHook getConstantFoldHook() { - return TF::ConstantFoldFallbackHook; - } - DialectConstantDecodeHook getDecodeHook() { return DecodeOpaqueTensorHook; } -}; +static bool init_hooks = ([] () { + TF::TensorFlowDialect::RegisterDecodeConstantHook(DecodeOpaqueTensorHook); +}(), true); } // anonymous namespace - -// Static initialization for TensorFlow dialect hooks registration. -static DialectHooksRegistration tf_hooks_registration("tf"); - } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td index 40339cebd31..4ed0307e2ef 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td @@ -85,7 +85,7 @@ def DecomposeResourceApplyMomentumOpNonNesterov : $var_resource, $accum_resource, $lr, $grad, $momentum, BoolAttr:$_, ConstBoolAttrFalse:$use_nesterov ), - [(TF_AddOp:$accum_new + [(TF_AddV2Op:$accum_new (TF_MulOp (CreateTFReadVariableOp $src_op, $grad, $accum_resource), $momentum @@ -107,7 +107,7 @@ def DecomposeResourceApplyMomentumOpNesterov : $var_resource, $accum_resource, $lr, $grad, $momentum, BoolAttr:$_, ConstBoolAttrTrue:$use_nesterov ), - [(TF_AddOp:$accum_new + [(TF_AddV2Op:$accum_new (TF_MulOp (CreateTFReadVariableOp $src_op, $grad, $accum_resource), $momentum @@ -117,7 +117,7 @@ def DecomposeResourceApplyMomentumOpNesterov : (TF_AssignVariableOp $accum_resource, $accum_new), (TF_AssignSubVariableOp $var_resource, - (TF_AddOp + (TF_AddV2Op (TF_MulOp $grad, $lr), (TF_MulOp $accum_new, (TF_MulOp $momentum, $lr)) ) @@ -175,7 +175,7 @@ def DecomposeResourceApplyKerasMomentumOpNesterov : ] >; -// Pattern to Decompose ResourceApplyAdagrad. +// Pattern to Decompose ResourceApplyAdagradV2. // This decomposition is only correct inside XLA as it ignores use_locking // attribute. // accum <- accum + grad * grad @@ -201,6 +201,21 @@ def DecomposeResourceApplyAdagradV2 : ] >; +// ResourceApplyAdagrad op can be canonicalized to ResourceApplyAdagradV2 with +// zero epsilon and then decomposed using DecomposeResourceApplyAdagradV2 +// pattern. +def DecomposeResourceApplyAdagrad : + Pattern< + (TF_ResourceApplyAdagradOp $var_resource, $accum_resource, $lr, $grad, + $use_locking, $update_slots), + [ + (TF_ConstOp:$zero_epsilon (GetScalarOfType<0> $grad)), + (TF_ResourceApplyAdagradV2Op $var_resource, $accum_resource, $lr, + $zero_epsilon, $grad, $use_locking, $update_slots + ) + ]>; + + // Pattern to Decompose ResourceApplyAdam without Nesterov momentum. // This decomposition is only correct inside XLA as it ignores use_locking // attribute. 
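Editor's note on the DecomposeResourceApplyAdagrad pattern added above: a brief sketch, in the same update-rule comment notation the surrounding patterns use, of what the canonicalization composes to. The placement of epsilon below follows the DecomposeResourceApplyAdagradV2 pattern it defers to and should be read as an assumption to verify against that pattern, not a definitive statement.

// ResourceApplyAdagrad(var, accum, lr, grad)
//   == ResourceApplyAdagradV2(var, accum, lr, /*epsilon=*/0, grad)
// which DecomposeResourceApplyAdagradV2 then expands to roughly:
//   accum <- accum + grad * grad
//   var   <- var - lr * grad / (sqrt(accum) + epsilon), with epsilon == 0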
@@ -342,7 +357,7 @@ def DecomposeResourceApplyCenteredRMSProp : ), [(TF_ConstOp:$one (GetScalarOfType<1> $grad)), (CreateTFReadVariableOp $src_op, $grad, $ms_resource), - (TF_AddOp:$ms_new + (TF_AddV2Op:$ms_new (TF_MulOp (TF_MulOp $grad, $grad), (TF_SubOp $one, $rho) @@ -354,7 +369,7 @@ def DecomposeResourceApplyCenteredRMSProp : ), (TF_AssignVariableOp $ms_resource, $ms_new), // mg = grad * (one - rho) + mg * rho; - (TF_AddOp:$mg_new + (TF_AddV2Op:$mg_new (TF_MulOp $grad, (TF_SubOp $one, $rho) @@ -366,7 +381,7 @@ def DecomposeResourceApplyCenteredRMSProp : ), (TF_AssignVariableOp $mg_resource, $mg_new), // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon) - (TF_AddOp:$mom_new + (TF_AddV2Op:$mom_new (TF_MulOp $momentum, (CreateTFReadVariableOp $src_op, $grad, $mom_resource)), (TF_DivOp @@ -374,7 +389,7 @@ def DecomposeResourceApplyCenteredRMSProp : (TF_SqrtOp (TF_SubOp $ms_new, - (TF_AddOp + (TF_AddV2Op (TF_MulOp $mg_new, $mg_new @@ -390,3 +405,45 @@ def DecomposeResourceApplyCenteredRMSProp : (TF_AssignSubVariableOp $var_resource, $mom_new) ] >; + +// This decomposition is only correct inside XLA as it ignores use_locking +// attribute. +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom +def DecomposeResourceApplyRMSProp : + Pattern< + (TF_ResourceApplyRMSPropOp:$src_op + $var_resource, $ms_resource, $mom_resource, $lr, $rho, $momentum, $epsilon, + $grad, ConstBoolAttrFalse:$use_locking + ), + [(TF_ConstOp:$one (GetScalarOfType<1> $grad)), + (CreateTFReadVariableOp $src_op, $grad, $ms_resource), + // ms <- rho * ms_{t-1} + (1-rho) * grad * grad + (TF_AddV2Op:$ms_new + (TF_MulOp + (CreateTFReadVariableOp $src_op, $grad, $ms_resource), + $rho + ), + (TF_MulOp + (TF_SquareOp $grad), + (TF_SubOp $one, $rho) + ) + ), + (TF_AssignVariableOp $ms_resource, $ms_new), + // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) + (TF_AddV2Op:$mom_new + (TF_MulOp $momentum, + (CreateTFReadVariableOp $src_op, $grad, $mom_resource)), + (TF_DivOp + (TF_MulOp $lr, $grad), + (TF_SqrtOp + (TF_AddV2Op $ms_new, $epsilon) + ) + ) + ), + (TF_AssignVariableOp $mom_resource, $mom_new), + // var <- var - mom + (TF_AssignSubVariableOp $var_resource, $mom_new) + ] + >; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc index b47378762a9..cc24c98a786 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc @@ -240,7 +240,7 @@ static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) { auto def_op = val.getDefiningOp(); #ifndef NDEBUG auto exec_dialect = - function.getContext()->getRegisteredDialect("tf_executor"); + function.getContext()->getLoadedDialect("tf_executor"); assert(def_op->getDialect() == exec_dialect && "unable to forward control dependencies"); #endif diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc index d8678e620f4..a5d76619416 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc @@ -157,14 +157,14 @@ static LogicalResult LowerIfOp(IfOp op) { // Set up the 'then' block. 
Block* then_block = builder.createBlock(merge_block); - Operation* call_op = CallFn(loc, get_operand, op.then_func(), &builder); + Operation* call_op = CallFn(loc, get_operand, op.then_function(), &builder); auto get_then_result = [&](int i) { return call_op->getResult(i); }; JumpToBlock(loc, get_then_result, merge_block, &builder); // Set up the 'else' block. Block* else_block = builder.createBlock(merge_block); - call_op = CallFn(loc, get_operand, op.else_func(), &builder); + call_op = CallFn(loc, get_operand, op.else_function(), &builder); auto get_else_result = [&](int i) { return call_op->getResult(i); }; JumpToBlock(loc, get_else_result, merge_block, &builder); @@ -190,8 +190,8 @@ static LogicalResult LowerWhileOp(WhileOp op) { OpBuilder builder(op_inst); - auto cond_fn = op.cond_func(); - auto body_fn = op.body_func(); + auto cond_fn = op.cond_function(); + auto body_fn = op.body_function(); // Split the block containing the While op into two blocks. One containing // operations before the While op and other containing the rest. Create two diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc index d23b977f0e3..87733bbbf3f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project @@ -31,8 +32,8 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #define DEBUG_TYPE "tf-functional-cf-to-region" @@ -53,8 +54,8 @@ struct FunctionalControlFlowToRegions // the input arguments are used as is (for IfOp) or block arguments of the same // type as the input arguments are created and then used as call arguments (for // While). -void CreateCall(Operation* op, FuncOp func, Region& caller_region, - ValueRange args, bool use_region_args) { +YieldOp CreateCall(Operation* op, FuncOp func, Region& caller_region, + ValueRange args, bool use_region_args) { assert(caller_region.empty() && "Expected empty region for newly created ops"); OpBuilder builder(caller_region); @@ -76,20 +77,31 @@ void CreateCall(Operation* op, FuncOp func, Region& caller_region, casted_args.push_back(arg); } auto call = builder.create(op->getLoc(), func, casted_args); - builder.create(op->getLoc(), call.getResults()); + return builder.create(op->getLoc(), call.getResults()); +} + +// Converts the condition for an IfOp/WhileOp to a boolean value. 
+Value ConvertConditionToBoolean(Operation* op, Value cond) { + if (auto ranked_type = cond.getType().dyn_cast()) + if (ranked_type.getRank() == 0 && + ranked_type.getElementType().isSignlessInteger(1)) + return cond; + + OpBuilder builder(op); + return builder.create(op->getLoc(), cond); } // Transform a functional IfOp to a region based IfRegionOp. LogicalResult ConvertIfOp(IfOp if_op) { + Value cond = ConvertConditionToBoolean(if_op, if_op.cond()); auto if_region = OpBuilder(if_op).create( - if_op.getLoc(), if_op.getResultTypes(), if_op.cond(), - if_op.is_stateless()); - CopyUnderscoredAttributes(if_op, if_region); + if_op.getLoc(), if_op.getResultTypes(), cond, if_op.is_stateless()); + CopyDeviceAndUnderscoredAttributes(if_op, if_region); - CreateCall(if_op, if_op.then_func(), + CreateCall(if_op, if_op.then_function(), /*caller_region=*/if_region.then_branch(), if_op.input(), /*use_region_args=*/false); - CreateCall(if_op, if_op.else_func(), + CreateCall(if_op, if_op.else_function(), /*caller_region=*/if_region.else_branch(), if_op.input(), /*use_region_args=*/false); if_op.replaceAllUsesWith(if_region.getResults()); @@ -101,12 +113,17 @@ LogicalResult ConvertWhileOp(WhileOp while_op) { auto while_region = OpBuilder(while_op).create( while_op.getLoc(), while_op.getResultTypes(), while_op.input(), while_op.is_stateless(), while_op.parallel_iterations()); - CopyUnderscoredAttributes(while_op, while_region); + CopyDeviceAndUnderscoredAttributes(while_op, while_region); - CreateCall(while_op, while_op.cond_func(), - /*caller_region=*/while_region.cond(), while_op.input(), - /*use_region_args=*/true); - CreateCall(while_op, while_op.body_func(), + YieldOp cond_yield = + CreateCall(while_op, while_op.cond_function(), + /*caller_region=*/while_region.cond(), while_op.input(), + /*use_region_args=*/true); + Value i1_cond = + ConvertConditionToBoolean(cond_yield, cond_yield.getOperand(0)); + cond_yield.setOperand(0, i1_cond); + + CreateCall(while_op, while_op.body_function(), /*caller_region=*/while_region.body(), while_op.input(), /*use_region_args=*/true); while_op.replaceAllUsesWith(while_region.getResults()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc index 175baeb627f..fbe0524ce8b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc @@ -91,7 +91,7 @@ struct ReluToFusedBatchNorm : public OpRewritePattern { // Build the newly fused operation to replace the batch norm OperationState state(batch_norm.getLoc(), - FusedBatchNormExOp::getOperationName()); + _FusedBatchNormExOp::getOperationName()); state.addOperands(batch_norm.getOperands()); if (side_input) state.operands.push_back(side_input); state.addTypes(batch_norm.getResultTypes()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc index bce18c0b4b7..4e507c8e760 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc @@ -104,10 +104,10 @@ LogicalResult HoistOpsAndAnnotateWithDevice(const Dialect* tf_dialect, } void LaunchToDeviceAttributePass::runOnFunction() { - const Dialect* tf_dialect = getContext().getRegisteredDialect("tf"); + const Dialect* tf_dialect = getContext().getLoadedDialect("tf"); if (!tf_dialect) { - 
signalPassFailure(); getFunction().emitError() << "'tf' dialect is not registered"; + return signalPassFailure(); } auto result = getFunction().walk([&](tf_device::LaunchOp launch) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc index e76a8da0b29..8123f50757e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project @@ -33,6 +35,34 @@ namespace mlir { namespace TF { namespace { +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc" + +// Helper method that returns an op from 'transpose_ops' that match criteria +// for an 'operand' and 'permutation' +TransposeOp ReuseExistingTranspose(const OpOperand* operand, + const SmallVector& permutation, + Operation* op, ConstOp permutation_op, + SmallVector* transpose_ops) { + for (auto it = transpose_ops->begin(); it != transpose_ops->end(); ++it) { + auto tranpose_op = *it; + for (auto tranpose_operand : tranpose_op.getOperands()) { + auto ranked_tranpose_type = + tranpose_operand.getType().dyn_cast_or_null(); + if (!ranked_tranpose_type) continue; + if (ranked_tranpose_type.getRank() == permutation.size() && + operand->get().getType() == + ShuffleRankedTensorType(ranked_tranpose_type, permutation)) { + TransposeOp transpose = tranpose_op; + transpose.getOperation()->moveBefore(op); + transpose.setOperand(0, operand->get()); + transpose.setOperand(1, permutation_op); + transpose_ops->erase(it); + return transpose; + } + } + } + return nullptr; +} // LayoutAssignmentPass assigns optimal data layout (data format) for all // layout sensitive operations. @@ -79,18 +109,7 @@ class MoveTransposesPass clEnumValN(Direction::kEnd, "end", "end of the block"))}; }; -using Permutation = SmallVector; - -Permutation GetDataFormatPermutation(StringRef from_data_format, - StringRef to_data_format) { - if (from_data_format == "NHWC" && to_data_format == "NCHW") { - return {0, 3, 1, 2}; - } else if (from_data_format == "NCHW" && to_data_format == "NHWC") { - return {0, 2, 3, 1}; - } else { - llvm_unreachable("Unknown data format combination"); - } -} +using Permutation = SmallVector; void LayoutAssignmentPass::runOnFunction() { FuncOp func = getFunction(); @@ -131,7 +150,7 @@ void LayoutAssignmentPass::runOnFunction() { OpBuilder builder = OpBuilder::atBlockEnd(op->getBlock()); auto perm_attr = [&](Permutation permutation) -> DenseIntElementsAttr { - auto perm_ty = RankedTensorType::get({4}, builder.getIntegerType(32)); + auto perm_ty = RankedTensorType::get({4}, builder.getIntegerType(64)); return DenseIntElementsAttr::get(perm_ty, permutation); }; @@ -202,6 +221,27 @@ void MoveTransposeBefore(Operation* op, SmallVector* work_list) { // Nothing to do here. 
if (!permutation_op || transpose_ops.empty()) return; + SmallVector permutation; + auto perm_attr = permutation_op.value().cast(); + for (const auto& value : perm_attr.getIntValues()) + permutation.push_back(value.getSExtValue()); + + // We want to make sure the shape of the operand equals the transposed shape. + // mismatch can happen if 'op' supports broadcasting and the operands have + // different ranks. + if (op->hasTrait()) { + auto transpose_op = *transpose_ops.begin(); + auto result_type = + transpose_op.getResult().getType().dyn_cast_or_null(); + auto is_valid_move = + llvm::all_of(op->getOperands(), [result_type](Value operand) -> bool { + auto operand_type = operand.getType().dyn_cast_or_null(); + return result_type && operand_type && result_type.hasRank() && + operand_type.hasRank() && + result_type.getRank() == operand_type.getRank(); + }); + if (!is_valid_move) return; + } // At this point we checked that we can safely move Transpose node before // `op`, and bypass all result transposes. @@ -228,16 +268,12 @@ void MoveTransposeBefore(Operation* op, SmallVector* work_list) { work_list->push_back(operand_op); // Try to reuse result transposes. - TransposeOp transpose; - if (!transpose_ops.empty()) { - transpose = transpose_ops.pop_back_val(); - transpose.getOperation()->moveBefore(op); - transpose.setOperand(0, operand.get()); - transpose.setOperand(1, permutation_op); - } else { + TransposeOp transpose = ReuseExistingTranspose( + &operand, permutation, op, permutation_op, &transpose_ops); + // If no transpose available for using, create new one. + if (!transpose) transpose = builder.create(loc, operand.get(), permutation_op); - } operand.set(transpose); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index ad241ef9488..8e93a7e7470 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -88,7 +88,7 @@ class ConvertConvOp : public OpConversionPattern { const int input_channels = conv_op.lhs().getType().cast().getDimSize( input_feature_dimension); - int feature_group_count = conv_op.feature_group_count().getSExtValue(); + int feature_group_count = conv_op.feature_group_count(); const bool is_depthwise_conv = input_channels == feature_group_count; std::string padding; @@ -250,7 +250,7 @@ class ConvertSliceOp : public OpConversionPattern { strides.getSplatValue().cast().getInt() != 1) return failure(); - rewriter.setInsertionPointAfter(slice_op); + rewriter.setInsertionPointAfter(slice_op.getOperation()); auto start_indices = slice_op.start_indices(); auto limit_indices = slice_op.limit_indices(); std::vector size_values; @@ -614,7 +614,65 @@ class ConvertReduceOpToTfMin : public OpConversionPattern { }; }; +class ConvertIotaOpToTfRange : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::IotaOp iota_op, ArrayRef args, + ConversionPatternRewriter &rewriter) const final { + RankedTensorType type = + iota_op.getType().dyn_cast_or_null(); + if (!type) return failure(); + + const uint64_t dimension = iota_op.iota_dimension(); + Type element_type = type.getElementType(); + Attribute start, limit, delta; + if (element_type.isa()) { + start = rewriter.getFloatAttr(element_type, 0.0); + limit = rewriter.getFloatAttr(element_type, type.getShape()[dimension]); + delta = rewriter.getFloatAttr(element_type, 1.0); + } else if 
(element_type.isa()) { + start = rewriter.getIntegerAttr(element_type, 0); + limit = rewriter.getIntegerAttr(element_type, type.getShape()[dimension]); + delta = rewriter.getIntegerAttr(element_type, 1); + } else { + return failure(); + } + + auto range_type = + RankedTensorType::get({type.getShape()[dimension]}, element_type); + Value start_op = rewriter.create(iota_op.getLoc(), start); + Value limit_op = rewriter.create(iota_op.getLoc(), limit); + Value delta_op = rewriter.create(iota_op.getLoc(), delta); + Value result = rewriter.create(iota_op.getLoc(), range_type, + start_op, limit_op, delta_op); + + if (type.getRank() > 1) { + std::vector reshape_shape(type.getRank(), 1); + reshape_shape[iota_op.iota_dimension()] = type.getShape()[dimension]; + auto reshape_type = RankedTensorType::get(reshape_shape, element_type); + Value reshape_shape_op = rewriter.create( + iota_op.getLoc(), rewriter.getI64TensorAttr(reshape_shape)); + result = rewriter.create(iota_op.getLoc(), reshape_type, + result, reshape_shape_op); + + Value broadcast_shape_op = rewriter.create( + iota_op.getLoc(), rewriter.getI64TensorAttr(type.getShape())); + result = rewriter.create(iota_op.getLoc(), type, + result, broadcast_shape_op); + } + + rewriter.replaceOp(iota_op, result); + return success(); + } +}; + class LegalizeHloToTf : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + public: LegalizeHloToTf() = default; LegalizeHloToTf(const LegalizeHloToTf &) {} @@ -765,7 +823,8 @@ void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns, MLIRContext *context) { populateWithGenerated(context, patterns); patterns->insert(context); + ConvertReduceOpToTfMin, ConvertReduceOpToTfSum, + ConvertIotaOpToTfRange>(context); } std::unique_ptr> CreateLegalizeHloToTfPass() { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index d67739a739b..f88488de27d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/util/tensor_format.h" @@ -55,18 +56,27 @@ static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, return DenseIntElementsAttr::get(ty, vals); } -// Returns int or float DenseElementsAttr with scalar shape with the given -// element type and the integer value. +// Returns int, float, or complex DenseElementsAttr with scalar shape with the +// given element type and the integer value. 
static DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) { RankedTensorType scalar_ty = RankedTensorType::get({}, ty); if (auto float_ty = ty.dyn_cast_or_null()) { FloatAttr attr = FloatAttr::get(float_ty, raw_value); return DenseElementsAttr::get(scalar_ty, attr); + } else if (auto int_ty = ty.dyn_cast_or_null()) { + IntegerAttr attr = IntegerAttr::get(int_ty, raw_value); + return DenseElementsAttr::get(scalar_ty, attr); + } else if (auto complex_ty = ty.dyn_cast_or_null()) { + Type complex_element_ty = complex_ty.getElementType(); + if (complex_element_ty.isF32()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } else if (complex_element_ty.isF64()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } } - - auto int_ty = ty.cast(); - IntegerAttr attr = IntegerAttr::get(int_ty, raw_value); - return DenseElementsAttr::get(scalar_ty, attr); + llvm_unreachable("unsupported type"); } // Returns float DenseElementsAttr with scalar shape with the specified value. @@ -111,34 +121,87 @@ Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) { return RankedTensorType::get(shape, ranked_ty.getElementType()); } +// Converts individual Values to a tensor of rank 1. Each input Value has rank 1 +// and size 1. +Value ValuesToRank1(PatternRewriter &rewriter, Location loc, Type dtype, + ArrayRef vals) { + int64_t length = vals.size(); + auto type = RankedTensorType::get({length}, dtype); + auto axis = rewriter.create( + loc, GetScalarOfType(rewriter.getIntegerType(64), 0)); + return rewriter.create(loc, type, ValueRange(vals), axis); +} + // Lowers AddN op to a sequence of AddV2 ops to accumulate operands. // +// Note that to improve the parallelism, AddN op uses tree-based reduction. +// For example, tf.AddN([0, 1, 2, 3, 4]) behaves as follows: +// +// 0 1 2 3 4 +// | | | | | +// ------- ------- | +// | | | +// 5 6 | +// | | | +// ------------- | +// | | +// 7 | +// | | +// ---------------- +// | +// 8 +// +// Example: +// // %result = "tf.AddN"(%0, %1, %2) // // is lowered to: // -// %sum_0 = "tf.AddV2"(%0, %1) -// %result = "tf.AddV2"(%sum_0, %2) +// %sum0 = "tf.AddV2"(%0, %1) +// %result = "tf.AddV2"(%sum0, %2) // -class LowerAddNOp : public OpRewritePattern { +// While +// +// %result = "tf.AddN"(%0, %1, %2, %3, %4) +// +// is lowered to: +// +// %sum0 = "tf.AddV2"(%0, %1) +// %sum1 = "tf.AddV2"(%2, %3) +// %sum2 = "tf.AddV2"(%sum0, %sum1) +// %result = "tf.AddV2"(%sum2, %4) +// +class LowerAddNOp : public RewritePattern { public: explicit LowerAddNOp(MLIRContext *context) - : OpRewritePattern(context) {} + : RewritePattern(TF::AddNOp::getOperationName(), + {TF::AddV2Op::getOperationName()}, 1, context) {} - LogicalResult matchAndRewrite(TF::AddNOp op, + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { + auto addn_op = cast(op); + // TODO(hinsu): Support variant with TensorList type. tf.AddV2 doesn't // support variant type so variant types require special handling. - if (getElementTypeOrSelf(op.getType()).isa()) return failure(); + if (getElementTypeOrSelf(addn_op.getType()).isa()) + return failure(); + llvm::SmallVector operands(addn_op.inputs().begin(), + addn_op.inputs().end()); - // TODO(hinsu): Improve parallelism by splitting operands in two halves and - // accumulating them first. 
- Value result = *op.inputs().begin(); - for (Value operand : llvm::drop_begin(op.inputs(), 1)) { - result = rewriter.create(op.getLoc(), result, operand); + int64_t n = operands.size(); + // Keep doing tree-based reduction when there are more than one operand. + while (n > 1) { + for (int64_t i = 0; i < n; i += 2) { + // Add two adjacent operands if applicable. + operands[i / 2] = + (i + 1 < n) ? rewriter.create( + addn_op.getLoc(), operands[i], operands[i + 1]) + : operands[i]; + } + n = (n + 1) / 2; } - rewriter.replaceOp(op, result); + rewriter.replaceOp(addn_op, operands[0]); return success(); } }; @@ -224,7 +287,7 @@ class LowerDynamicStitchOp : public OpRewritePattern { reshaped_data.getType().cast().getShape()[0]; auto items = rewriter.create( loc, SmallVector(num_items, item_ty), reshaped_data, - /*axis=*/APInt(64, 0)); + /*axis=*/0); for (auto index_item : llvm::zip(index_attr, items.getResults())) { int64_t output_index = std::get<0>(index_item).getSExtValue(); Value item = std::get<1>(index_item); @@ -320,7 +383,7 @@ class LowerPackOp : public OpRewritePattern { loc, DenseElementsAttr::get( RankedTensorType::get({}, rewriter.getIntegerType(64)), op.axis())); - int64_t axis = op.axis().getSExtValue(); + int64_t axis = op.axis(); Type prev_input_ty, inferred_ty; SmallVector expanded_inputs; @@ -344,6 +407,187 @@ class LowerPackOp : public OpRewritePattern { } }; +// Lowers SpaceToBatchND by reducing to reshape(transpose(reshape(pad(input)))). +// +// Before rewrite: +// output = SpaceToBatchND(input, block_shape, paddings) +// Let: +// [batch] + spatial_shape + remaining_shape = input.shape +// M = spatial_shape.rank +// After rewrite: +// padded = zero-pad input with paddings +// The spatial_shape component of input.shape pads with paddings[*, 0] +// before each dimension, and paddings[*, 1] after each dimension. 
+// reshaped = reshape padded to: +// [batch] +// + [padded.shape[1]/block_shape[0], block_shape[0], ..., +// padded.shape[M]/block_shape[M-1], block_shape[M-1]] +// + remaining_shape +// permuted = transpose reshaped to: +// block_shape +// + [batch] +// + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]] +// + remaining_shape +// result = reshape permuted to: +// [batch * product(block_shape)] +// + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]] +// + remaining_shape +class LowerSpaceToBatchNDOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SpaceToBatchNDOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto input_type = op.input().getType().cast(); + if (!input_type.hasStaticShape()) { + return failure(); + } + ArrayRef input_shape = input_type.getShape(); + auto block_shape_type = op.block_shape().getType().cast(); + if (!block_shape_type.hasStaticShape()) { + return failure(); + } + auto paddings_type = op.paddings().getType().cast(); + + int64_t input_rank = input_type.getRank(); + int64_t block_rank = block_shape_type.getNumElements(); + int64_t remaining_rank = input_rank - 1 - block_rank; + if (remaining_rank < 0) { + // TODO(b/157475606): Move this check to ::Verify + return failure(); + } + + auto block_shape_i64_type = RankedTensorType::get( + block_shape_type.getShape(), rewriter.getIntegerType(64)); + auto block_shape_i64 = rewriter.create( + loc, block_shape_i64_type, op.block_shape()); + + auto paddings_i64_type = RankedTensorType::get(paddings_type.getShape(), + rewriter.getIntegerType(64)); + auto paddings_i64 = + rewriter.create(loc, paddings_i64_type, op.paddings()); + + auto pad00 = rewriter.create( + loc, DenseElementsAttr::get( + RankedTensorType::get({1, 2}, rewriter.getIntegerType(64)), + {0, 0})); + SmallVector full_paddings_list{pad00, paddings_i64}; + full_paddings_list.append(remaining_rank, pad00); + auto full_paddings_type = + RankedTensorType::get({input_rank, 2}, rewriter.getIntegerType(64)); + auto zero_i64 = rewriter.create( + loc, GetScalarOfType(rewriter.getIntegerType(64), 0)); + // Extends paddings to all dimensions of input by adding 0s to non-block + // dimensions. + auto full_paddings = rewriter.create( + loc, full_paddings_type, full_paddings_list, zero_i64); + + SmallVector padded_shape(input_rank, ShapedType::kDynamicSize); + auto padded_type = + RankedTensorType::get(padded_shape, rewriter.getF32Type()); + // padded = pad(input, full_paddings) + auto padded = + rewriter.create(loc, padded_type, op.input(), full_paddings); + + auto paddings_sum_type = + RankedTensorType::get({input_rank}, rewriter.getIntegerType(64)); + auto one_i64 = rewriter.create( + loc, GetScalarOfType(rewriter.getIntegerType(64), 1)); + // paddings_sum = paddings[*,0] + paddings[*,1] + auto paddings_sum = rewriter.create(loc, paddings_sum_type, + full_paddings, one_i64); + + // input_shape_tensor = input.shape + auto input_shape_tensor = rewriter.create( + loc, + DenseElementsAttr::get( + RankedTensorType::get({input_rank}, rewriter.getIntegerType(64)), + input_shape)); + + // padded_shape_tensor is the shape of padded. 
+ auto padded_shape_tensor = + rewriter.create(loc, paddings_sum, input_shape_tensor); + + auto zero_i32 = rewriter.create( + loc, GetScalarOfType(rewriter.getIntegerType(32), 0)); + SmallVector padded_shape_splits_types( + input_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64))); + SmallVector padded_shape_splits( + rewriter + .create(loc, padded_shape_splits_types, zero_i32, + padded_shape_tensor) + .output()); + + SmallVector block_shape_splits_types( + block_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64))); + SmallVector block_shape_splits( + rewriter + .create(loc, block_shape_splits_types, zero_i32, + block_shape_i64) + .output()); + + SmallVector outer_shape_vals; + for (int64_t i = 0; i < block_rank; ++i) { + // TODO(b/157475606): Insert tf.Assert that the following division has + // remainder 0. + outer_shape_vals.push_back(rewriter.create( + loc, padded_shape_splits[1 + i], block_shape_splits[i])); + } + + SmallVector reshaped_shape_vals{padded_shape_splits[0]}; + for (int64_t i = 0; i < block_rank; ++i) { + reshaped_shape_vals.push_back(outer_shape_vals[i]); + reshaped_shape_vals.push_back(block_shape_splits[i]); + } + for (int64_t i = 1 + block_rank; i < input_rank; ++i) { + reshaped_shape_vals.push_back(padded_shape_splits[i]); + } + auto reshaped_shape = ValuesToRank1( + rewriter, loc, rewriter.getIntegerType(64), reshaped_shape_vals); + + SmallVector permutation_vals; + for (int64_t i = 0; i < block_rank; ++i) { + permutation_vals.push_back(2 + 2 * i); + } + permutation_vals.push_back(0); + for (int64_t i = 0; i < block_rank; ++i) { + permutation_vals.push_back(1 + 2 * i); + } + for (int64_t i = 1 + block_rank; i < input_rank; ++i) { + permutation_vals.push_back(block_rank + i); + } + auto permutation = rewriter.create( + loc, GetI64ElementsAttr(permutation_vals, &rewriter)); + + auto output_batch = padded_shape_splits[0]; + for (int64_t i = 0; i < block_rank; ++i) { + output_batch = + rewriter.create(loc, output_batch, block_shape_splits[i]); + } + SmallVector output_shape_vals{output_batch}; + for (int64_t i = 0; i < block_rank; ++i) { + output_shape_vals.push_back(outer_shape_vals[i]); + } + for (int64_t i = 1 + block_rank; i < input_rank; ++i) { + output_shape_vals.push_back(padded_shape_splits[i]); + } + auto output_shape = ValuesToRank1( + rewriter, loc, rewriter.getIntegerType(64), output_shape_vals); + auto reshaped = rewriter.create(loc, padded, reshaped_shape); + auto permuted = + rewriter.create(loc, reshaped, permutation); + + // Sometimes the result type is more specific than what the reshape builder + // can infer. + auto result_type = op.getResult().getType(); + rewriter.replaceOpWithNewOp(op, result_type, permuted, + output_shape); + + return success(); + } +}; + // Lowers `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness hints, // since we currently don't have an implementation that can use this // information. Adds appropriate casts where necessary to align element types @@ -388,12 +632,37 @@ class LowerSparseMatMulOp : public OpRewritePattern { } }; +// Lowers _UnaryOpsComposition op as a series of original TensorFlow ops that +// were fused together. +class Lower_UnaryOpsComposition + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::_UnaryOpsCompositionOp op, + PatternRewriter &rewriter) const override { + Value result = op.x(); + for (StringRef op_name : op.op_names().getAsValueRange()) { + std::string full_name = "tf." 
+ op_name.str(); + // All ops in the sequences have the same result type as the original + // result type. + OperationState state(op.getLoc(), full_name, /*operands=*/{result}, + /*types=*/{op.getType()}, /*attributes=*/{}); + Operation *op = rewriter.createOperation(state); + result = op->getResult(0); + } + rewriter.replaceOp(op, {result}); + return success(); + } +}; + } // namespace void PopulateLoweringTFPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { patterns->insert(context); + LowerPackOp, LowerSpaceToBatchNDOp, LowerSparseMatMulOp, + Lower_UnaryOpsComposition>(context); populateWithGenerated(context, patterns); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index 6b7d7178ab6..f7a867f3130 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -195,8 +195,7 @@ def : Pat<(TF_PadOp TensorOf<[AnySignlessInteger, AnyFloat]>:$input, $paddings), // Reciprocal op patterns. //===----------------------------------------------------------------------===// -// TODO(hinsu): Support complex and unsigned input types. -def LowerReciprocal : Pat<(TF_ReciprocalOp TF_SintOrFpTensor:$x), +def LowerReciprocal : Pat<(TF_ReciprocalOp $x), (TF_DivOp (TF_ConstOp (GetScalarOfType<1> $x)), $x)>; //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc index 72f7a3a438c..25bd53ee73c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc @@ -17,12 +17,17 @@ limitations under the License. #include #include +#include "llvm/Support/FormatVariadic.h" #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" namespace mlir { namespace TFDevice { @@ -30,6 +35,7 @@ namespace TFDevice { namespace { constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; +constexpr char kAllowSoftPlacementAttr[] = "allow_soft_placement"; // This pass marks unsupported ops in a device cluster with // `_xla_outside_compilation` attribute so the operations will run on the host @@ -41,6 +47,36 @@ struct MarkOpsForOutsideCompilation void runOnOperation() override; }; +// Adds any canonicalization patterns to list of supported `patterns`. +// TODO(b/161726307): Move or import the relevant patterns to LowerTF pass and +// remove this. +void AddCanonicalizationPatterns(MLIRContext* context, + OwningRewritePatternList* patterns) { + for (auto* op : context->getRegisteredOperations()) + op->getCanonicalizationPatterns(*patterns, context); +} + +// TODO(b/159128666): Check the control flow legalization passes instead once +// added. 
+void AddSupportedControlFlowOps(MLIRContext* context, + llvm::DenseSet* supported_ops) { + supported_ops->insert( + OperationName(TF::IfRegionOp::getOperationName(), context)); + supported_ops->insert( + OperationName(TF::WhileRegionOp::getOperationName(), context)); + supported_ops->insert( + OperationName(TF::YieldOp::getOperationName(), context)); +} + +// These embedding ops are rewritten when running TPUCompileOp. +void AddRewrittenEmbeddingOps(MLIRContext* context, + llvm::DenseSet* supported_ops) { + supported_ops->insert(OperationName( + TF::RecvTPUEmbeddingActivationsOp::getOperationName(), context)); + supported_ops->insert(OperationName( + TF::SendTPUEmbeddingGradientsOp::getOperationName(), context)); +} + bool HasStringOperand(Operation& op) { for (auto operand : op.getOperands()) { if (getElementTypeOrSelf(operand).isa()) return true; @@ -55,29 +91,143 @@ bool HasStringResult(Operation& op) { return false; } -// Checks if the op is supported inside of a device cluster. -bool IsSupportedOp(Operation& op) { - if (HasStringOperand(op) || HasStringResult(op)) { - return false; - } - return true; +bool MatchesPattern(Operation& op, + const llvm::DenseSet& supported_ops) { + return (supported_ops.contains(op.getName())); } -LogicalResult MarkUncompilableOps(Block* block) { - for (Operation& op : *block) { - if (!IsSupportedOp(op)) { - op.setAttr(kXlaOutsideCompilationAttr, - StringAttr::get("auto", op.getContext())); - } +// Checks if the op is supported inside of a device cluster. Ops not +// in `tf_dialect` are considered supported. +bool IsSupportedOp(Operation& op, + const llvm::DenseSet& supported_ops, + const Dialect* tf_dialect) { + if (op.getDialect() != tf_dialect) + return true; + else + return !HasStringOperand(op) && !HasStringResult(op) && + (MatchesPattern(op, supported_ops) || + mhlo::IsOpAllowedTf2XlaFallback(&op)); +} + +// Checks all regions of `op` for captured string operands. +bool HasCapturedStringOperand(Operation* op) { + bool string_operand = false; + for (auto& region : op->getRegions()) { + mlir::visitUsedValuesDefinedAbove( + region, region, [&](mlir::OpOperand* operand) { + if (getElementTypeOrSelf(operand->get()).isa()) + string_operand = true; + }); + if (string_operand) return string_operand; } + return string_operand; +} + +// Marks uncompilable ops that are in `tf_dialect` for outside compilation. +LogicalResult MarkUncompilableOps( + const Dialect* tf_dialect, Block* block, + llvm::DenseSet& supported_ops) { + // Automatically marked ops for outside compilation have + // `_xla_outside_compilation` attribute value of "auto" plus + // an increasing counter. Manually marked ops for outside compilation only + // have an increasing counteri for the attribute value. Therefore there is no + // collision in + // `_xla_outside_compilation` attribute between automatically and manually + // marking ops. 
+ int outside_compiled_cluster_counter = 0; + block->walk([&](Operation* op) { + if (!IsSupportedOp(*op, supported_ops, tf_dialect)) { + op->setAttr( + kXlaOutsideCompilationAttr, + StringAttr::get( + llvm::formatv("auto{0}", outside_compiled_cluster_counter).str(), + op->getContext())); + outside_compiled_cluster_counter++; + } + if (llvm::isa(op)) { + if (HasCapturedStringOperand(op)) { + op->setAttr( + kXlaOutsideCompilationAttr, + StringAttr::get( + llvm::formatv("auto{0}", outside_compiled_cluster_counter) + .str(), + op->getContext())); + outside_compiled_cluster_counter++; + } + } + }); return success(); } +// Unmarks outside compilation for any op that has parents already +// marked for outside compilation since the child will be extracted +// anyways. +void UnmarkChildren(Block* block) { + block->walk([&](Operation* op) { + if (!op->getAttrOfType(kXlaOutsideCompilationAttr)) return; + Operation* iter_op = op; + bool remove_attr = false; + while (auto* parent_op = iter_op->getParentOp()) { + if (parent_op->getAttrOfType(kXlaOutsideCompilationAttr)) { + remove_attr = true; + break; + } + iter_op = parent_op; + } + if (remove_attr) op->removeAttr(kXlaOutsideCompilationAttr); + }); +} + void MarkOpsForOutsideCompilation::runOnOperation() { auto module = getOperation(); + const Dialect* tf_dialect = getContext().getLoadedDialect("tf"); + if (!tf_dialect) { + getOperation().emitError() << "'tf' dialect is not registered"; + return signalPassFailure(); + } + OwningRewritePatternList patterns; + mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns); + TF::PopulateLoweringTFPatterns(module.getContext(), &patterns); + AddCanonicalizationPatterns(module.getContext(), &patterns); + + // `supported_ops` contains the name of all of the ops that can potentially be + // lowered into HLO on the device. This doesn't always mean that the op can + // be lowered in the future passes but if the op is not in this set, it can't + // be lowered in a subsequent pass. + llvm::DenseSet supported_ops; + for (auto& pattern : patterns) { + Optional root_kind = pattern->getRootKind(); + if (root_kind.hasValue()) supported_ops.insert(root_kind.getValue()); + } + AddSupportedControlFlowOps(module.getContext(), &supported_ops); + AddRewrittenEmbeddingOps(module.getContext(), &supported_ops); + + auto result = module.walk([&](tf_device::ClusterOp cluster) { + // Only if `allow_soft_placement` attribute is true should we mark ops + // for outside compilation. + auto soft_placement_attr = + cluster.getAttrOfType(kAllowSoftPlacementAttr); + if (!(soft_placement_attr && soft_placement_attr.getValue())) { + return WalkResult::advance(); + } + if (failed( + MarkUncompilableOps(tf_dialect, &cluster.GetBody(), supported_ops))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (result.wasInterrupted()) return signalPassFailure(); module.walk([&](tf_device::ClusterOp cluster) { - MarkUncompilableOps(&cluster.GetBody()); + // Only if `allow_soft_placement` attribute is true should we unmark ops + // for outside compilation. 
+ auto soft_placement_attr = + cluster.getAttrOfType(kAllowSoftPlacementAttr); + if (!(soft_placement_attr && soft_placement_attr.getValue())) { + return; + } + UnmarkChildren(&cluster.GetBody()); }); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index 6fee693554e..b81e390580d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -109,13 +109,14 @@ class ResourceAnalyzer { return; } if (auto if_op = dyn_cast(op)) { - for (auto callee : {if_op.then_func(), if_op.else_func()}) { + for (auto callee : {if_op.then_function(), if_op.else_function()}) { PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input()); } return; } if (auto while_op = dyn_cast(op)) { - for (auto callee : {while_op.cond_func(), while_op.body_func()}) { + for (auto callee : + {while_op.cond_function(), while_op.body_function()}) { PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input()); } return; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc index 527af0934ea..352604955c0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc @@ -39,6 +39,10 @@ namespace { struct ParallelizeEmbeddingParamsOpsPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override; }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 3afadd2b06d..a4ddb713ec0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -79,6 +79,11 @@ std::unique_ptr> CreateRewriteTPUEmbeddingOpsPass(); // Performs specific fusion for GPU targets. std::unique_ptr> CreateGpuOpFusionPass(); +// Create a pass that convert ops that copy tensors between devices, e.g. +// tf.Identity. +std::unique_ptr> +CreateTensorDeviceCopyConversionPass(); + struct LayoutOptimizationPipelineOptions : public PassPipelineOptions { Option force_data_format{ @@ -162,6 +167,12 @@ void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList* patterns, // future these fusions may be codegen'd automatically. std::unique_ptr> CreateFusedKernelMatcherPass(); +// Fuses operations defining `ContractionFusableInterface` interface into the +// contraction operations (MatMul, Conv2D, etc...). This is a more general +// version of `CreateFusedKernelMatcherPass` that relies on codegen to compose +// contraction fusions together. +std::unique_ptr> CreateContractionFusionPass(); + // Creates function pass to select device index/fold tf.DeviceIndex. std::unique_ptr> CreateDeviceIndexSelectorPass(); @@ -239,7 +250,7 @@ std::unique_ptr> CreateReplicateInvariantOpHoistingPass(); // Creates a pass that forms replica `tf_executor.island` from a single // `tf_device.replicate` island. -std::unique_ptr> CreateReplicateToIslandPass(); +std::unique_ptr> CreateReplicateToIslandPass(); // Creates a pass that creates `tf_executor.island` from a single // `tf_device.parallel_execute` island. 
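Editor's note: a minimal sketch of how the pass factory functions declared above in passes.h might be wired into a pipeline. The helper name and the pass selection/ordering here are hypothetical and purely illustrative; they are not part of this change, and real pipelines should be consulted for correct placement.

// Illustrative only; assumes both factories live in namespace mlir::TF and
// return function-level passes, as the declarations above suggest.
#include "mlir/IR/Function.h"       // mlir::FuncOp
#include "mlir/Pass/PassManager.h"  // mlir::OpPassManager
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

void AddExampleTFFunctionPasses(mlir::OpPassManager &pm) {  // hypothetical helper
  // Fuse ops implementing ContractionFusableInterface into contractions
  // such as MatMul/Conv2D.
  pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateContractionFusionPass());
  // Convert ops that copy tensors between devices, e.g. tf.Identity.
  pm.addNestedPass<mlir::FuncOp>(
      mlir::TF::CreateTensorDeviceCopyConversionPass());
}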
@@ -269,7 +280,15 @@ std::unique_ptr> CreateLaunchToDeviceAttributePass(); namespace TFTPU { // Creates a pass that forms clusters from operations of the same // `_tpu_replicate` attribute. -std::unique_ptr> CreateTPUClusterFormationPass(); +std::unique_ptr> CreateTPUClusterFormationPass(); + +// Creates a pass that cleans up `_tpu_replicate` attribute on operations +// that are inside a cluster. +std::unique_ptr> +CreateTPUClusterCleanupAttributesPass(); + +// Creates a pass that removes Identity/IdentityN ops from a cluster. +std::unique_ptr> CreateTPUIdentityPruningPass(); // Creates a pass that allows TPU program inputs to have layouts determined at // run time. @@ -279,6 +298,10 @@ std::unique_ptr> CreateTPUDynamicLayoutPass(); // `tf_device.launch_func` `padding_map` attribute to its encapsulated function. std::unique_ptr> CreateTPUDynamicPaddingMapperPass(); +// Creates a pass that adds `tf.ReadVariableOp` to a TPU cluster for resources +// the cluster only writes to. +std::unique_ptr> CreateTPUResourceReadForWritePass(); + // Creates a pass that rewrites `tf_device.launch_func` on TPUs into TPU runtime // ops. std::unique_ptr> CreateTPURewritePass(); @@ -287,18 +310,29 @@ std::unique_ptr> CreateTPURewritePass(); // computation. std::unique_ptr> CreateTPUShardingIdentificationPass(); +// Creates a pass that moves `tf.AssignVariableOp` into a +// `tf_device.parallel_execute` region if the `tf.AssignVariableOp` is the +// only consumer of a `tf_device.parallel_execute` result. +std::unique_ptr> +CreateTPUParallelExecuteSinkResourceWritePass(); + // Creates a pass that merges device variable reads/updates into the surrounded // TPUExecute node. This allows the execute node to perform in-place variable // updates. std::unique_ptr> CreateTPUMergeVariablesWithExecutePass(); +// Creates a pass that wraps ReadVariableOp/AssignVariable op that consumes a +// packed tensor to have same device placement as underlying TPU device. +std::unique_ptr> CreateTPUColocateCompositeResourceOps(); + // Creates a pass that adds ops which perform formatting on variables at // run-time according to compilation result. std::unique_ptr> CreateTPUVariableReformattingPass(); // Creates a pass that groups outside compiled operations (CPU ops inside TPU // cluster) into clusters that can be extracted and run on the CPU. -std::unique_ptr> CreateTPUOutsideCompilationClusterPass(); +std::unique_ptr> +CreateTPUOutsideCompilationClusterPass(); // Creates a pass that extracts outside compilation (CPU ops inside TPU cluster) // at head/tail of TPU cluster to run before/after TPU computation. @@ -321,6 +355,7 @@ std::unique_ptr> CreateTPUExtractOutsideCompilationPass(); // Populates the supplied passmanager with the passes required to run the +// bridge. void CreateTPUBridgePipeline(OpPassManager& pm); // Populates the supplied passmanager with the passes required to run the diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc index ba876e08fbb..1e403bff0eb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc @@ -36,8 +36,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #define DEBUG_TYPE "tf-region-cf-to-functional" @@ -158,9 +158,11 @@ void ExtractSingleBlockRegion(Region& region, StringRef name, } // Returns call for region with single call whose result feeds into the -// terminator of the region. Returns none if the region doesn't contain just -// call and non-truncting casts ops. -llvm::Optional IsSingleCallRegion(Region& region) { +// terminator of the region. if `allow_to_bool` is true, also allows a single +// ToBoolOp between the region yield and the call. Returns none if the region +// does not conform to this pattern. +llvm::Optional IsSingleCallRegion(Region& region, + bool allow_to_bool = false) { if (!llvm::hasSingleElement(region)) return llvm::None; Block& block = region.front(); @@ -169,31 +171,44 @@ llvm::Optional IsSingleCallRegion(Region& region) { if (it == block.rend()) return llvm::None; + // Operation which is expected to consume all the call results. + Operation* call_consumer = yield; + + // Allow a single ToBoolOp between the call and the yield (valid only + // when the yield has a single operand) + if (allow_to_bool && yield.getNumOperands() == 1 && isa(*it)) { + if (it->getResult(0) != yield.getOperand(0)) return llvm::None; + call_consumer = cast(*it); + it++; + } + // Check if there is a Call before the Yield. CallOp call = dyn_cast(*it++); if (!call) return llvm::None; + // All call results should feed into expected consumer + // All results of the call should feed into the yield. + if (call.getNumResults() != call_consumer->getNumOperands()) + return llvm::None; + + for (auto res_it : llvm::zip(call.getResults(), call_consumer->getOperands())) + if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None; + // There can only be non-truncating cast op's prior to the call. for (; it != block.rend(); ++it) { CastOp cast = dyn_cast(*it); if (!cast || cast.Truncate()) return llvm::None; } - // All results of the call should feed into the yield. - if (call.getNumResults() != yield.getNumOperands()) return llvm::None; - - for (auto res_it : llvm::zip(call.getResults(), yield.getOperands())) - if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None; - return call; } -using MatcherFn = function_ref; +using ArgMatcherFn = function_ref; // Returns whether the arguments of the given 2 calls are match (after looking // through cast ops). `matcher` is the predicate used to check if two arguments // match. -bool MatchCallArgs(CallOp first, CallOp second, MatcherFn matcher) { +bool MatchCallArgs(CallOp first, CallOp second, ArgMatcherFn matcher) { if (first.getNumOperands() != second.getNumOperands()) return false; Region& first_region = *first.getParentRegion(); @@ -225,38 +240,37 @@ struct TrivialTransformInfo { // List of callee names (one for each region). llvm::SmallVector callee_names; - // Constructor will analyze the 2 regions. - TrivialTransformInfo(Region& first, Region& second, MatcherFn matcher); + // Analyzes the given calls (from regions attached to the same parent op) to + // check if the parent op be transformed to functional form trivially (i.e., + // reusing existing functions and without outlining). 
This is possible when + // all the regions are single call regions (checked using matchers outside + // this class) and the all the calls match using the given argument matcher. + // + // If such a trivial transformation is possible, stash the relevant + // information needed for the transformation, else indicate that a trivial + // transformation is not possible by setting `can_transform` to false. + TrivialTransformInfo(llvm::Optional first_call, + llvm::Optional second_call, + ArgMatcherFn arg_matcher) { + if (!first_call || !second_call) return; + + if (!MatchCallArgs(first_call.getValue(), second_call.getValue(), + arg_matcher)) + return; + + can_transform = true; + callee_names = {first_call.getValue().getCallee(), + second_call.getValue().getCallee()}; + } }; -// Analyzes the given set of regions (attached to the same parent op) to check -// if the parent op be transformed to functional form trivially (i.e., reusing -// existing functions and without outlining). This is possible when all the -// regions are single call regions and the all the calls have the same -// arguments. -// -// If such a trivial transformation is possible, stash the relevant information -// needed for the transformation, else indicate that a trivial transformation is -// not possible by setting `can_transform` to false. -TrivialTransformInfo::TrivialTransformInfo(Region& first, Region& second, - MatcherFn matcher) { - auto call0 = IsSingleCallRegion(first); - auto call1 = IsSingleCallRegion(second); - if (!call0 || !call1) return; - - if (!MatchCallArgs(call0.getValue(), call1.getValue(), matcher)) return; - - can_transform = true; - callee_names = {call0.getValue().getCallee(), call1.getValue().getCallee()}; -} - // Transform IfRegionOp to IfOp. LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) { llvm::SmallVector extern_values; // For IfOp, arguments of calls in the then and else regions match if they // are the same value. - auto if_matcher = [&](Value first, Region&, Value second, Region&) { + auto if_arg_matcher = [&](Value first, Region&, Value second, Region&) { if (first != second) return false; // collect the call arguments post lookup through cast Op's @@ -264,8 +278,9 @@ LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) { return true; }; - const TrivialTransformInfo tti(if_region.then_branch(), - if_region.else_branch(), if_matcher); + const TrivialTransformInfo tti(IsSingleCallRegion(if_region.then_branch()), + IsSingleCallRegion(if_region.else_branch()), + if_arg_matcher); std::string then_name, else_name; @@ -293,16 +308,23 @@ LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) { worklist, /*extern_values_passthrough=*/false); } + // Look through ToBool operations for the condition. + Value cond = if_region.cond(); + auto to_bool = dyn_cast_or_null(cond.getDefiningOp()); + if (to_bool) cond = to_bool.getOperand(); + // Once we have the `then` and `else` functions ready (either outlined or // existing ones), replace the region based op with a functional control flow // op. 
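The checks above are easiest to read as a single predicate: a region is a single-call region only when the call's results are exactly the operands of the expected consumer (the yield, or the one ToBool feeding the yield when `allow_to_bool` is set). The standalone helper below restates that predicate for illustration only; its name is not part of the change and it relies on the CallOp alias and headers already used in this file.

// Illustrative restatement of the result-forwarding check in
// IsSingleCallRegion; `call_consumer` is the yield or the single ToBool.
static bool CallResultsFeedConsumer(CallOp call, Operation* call_consumer) {
  if (call.getNumResults() != call_consumer->getNumOperands()) return false;
  for (auto pair : llvm::zip(call.getResults(), call_consumer->getOperands()))
    if (std::get<0>(pair) != std::get<1>(pair)) return false;
  return true;
}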
OpBuilder builder(if_region); auto if_op = builder.create( - if_region.getLoc(), if_region.getResultTypes(), if_region.cond(), - extern_values, then_name, else_name, if_region.is_stateless()); - CopyUnderscoredAttributes(if_region, if_op); + if_region.getLoc(), if_region.getResultTypes(), cond, extern_values, + then_name, else_name, if_region.is_stateless()); + CopyDeviceAndUnderscoredAttributes(if_region, if_op); if_region.replaceAllUsesWith(if_op.getResults()); if_region.erase(); + + if (to_bool && to_bool.use_empty()) to_bool.erase(); return success(); } @@ -315,8 +337,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp( // cannot do a trivial transformation because post transform, we will need to // pass this extern value as an argument to the function, so we cannot use the // existing function as is. - auto while_matcher = [](Value first, Region& first_region, Value second, - Region& second_region) { + auto while_arg_matcher = [](Value first, Region& first_region, Value second, + Region& second_region) { if (!first.isa() || !second.isa()) return false; BlockArgument first_block_arg = first.cast(); @@ -329,8 +351,9 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp( second_block_arg.getParentBlock() == &second_region.front(); }; - const TrivialTransformInfo tti(while_region.cond(), while_region.body(), - while_matcher); + const TrivialTransformInfo tti( + IsSingleCallRegion(while_region.cond(), /*allow_to_bool=*/true), + IsSingleCallRegion(while_region.body()), while_arg_matcher); // All existing inputs to while region are inputs to the functional while. auto new_inputs = llvm::to_vector<4>(while_region.getOperands()); @@ -376,7 +399,7 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp( auto while_op = builder.create( while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name, while_region.parallel_iterations(), while_region.is_stateless()); - CopyUnderscoredAttributes(while_region, while_op); + CopyDeviceAndUnderscoredAttributes(while_region, while_op); // Redirect old results to new results. for (auto it : llvm::zip( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc index 031d57e99ba..96ff2890558 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc @@ -151,7 +151,7 @@ bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) { // invariant. Shape ops are rewritten to be invariant when possible, prior to // hoisting ops. void HoistReplicateInvariantOps(tf_device::ReplicateOp replicate_op) { - const int num_replicas = replicate_op.n().getLimitedValue(); + const int num_replicas = replicate_op.n(); Block* replicate_block = &replicate_op.GetBody(); replicate_op.walk([&](TF::ShapeOp shape_op) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index b16868311f0..5b70729ee80 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -32,12 +33,14 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/core/platform/logging.h" namespace mlir { @@ -45,10 +48,11 @@ namespace TFDevice { namespace { constexpr char kDeviceAttr[] = "device"; constexpr char kReplicaIdAttr[] = "_xla_replica_id"; +constexpr char kDeviceOrdinalAttr[] = "device_ordinal"; struct ReplicateToIslandPass - : public PassWrapper { - void runOnFunction() override; + : public PassWrapper> { + void runOnOperation() override; }; // Returns whether op requires `_xla_replica_id` attribute. @@ -57,29 +61,207 @@ bool RequiresReplicaIDAttribute(Operation* op) { TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op); } -// Adds integer attribute that represents replica id for replicated ops that -// require replica id attribute. -void AddReplicaIdToOpsInReplicatedRegion(OpBuilder* builder, Region* region, - const int replica_id) { - region->walk([&](Operation* replicated_op) { - if (RequiresReplicaIDAttribute(replicated_op)) - replicated_op->setAttr(kReplicaIdAttr, - builder->getI32IntegerAttr(replica_id)); +bool RequiresDeviceOrdinalAttribute(Operation* op) { + return llvm::isa(op) || + llvm::isa(op); +} + +// Checks if a region contains ops that are replica variant. +bool HasReplicaVariantOps(Region& region, + const llvm::Optional& devices) { + auto result = region.walk([&](Operation* op) { + if (RequiresReplicaIDAttribute(op) || + (devices.hasValue() && RequiresDeviceOrdinalAttribute(op))) + return WalkResult::interrupt(); + + if (auto launch = dyn_cast(op)) + if (devices.hasValue() && devices.getValue().get(launch.device())) + return WalkResult::interrupt(); + + return WalkResult::advance(); }); + return result.wasInterrupted(); +} + +// Collects all functions reachable from a region, including transitive ones. 
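The helper named by the comment above (defined next) computes the transitive set of callees, building on SymbolTable::getSymbolUses over a region. The reduced sketch below shows just the direct-callee step; the transitive closure is then a plain worklist over each callee's callable region, as in the definition that follows. The helper name here is an assumption for illustration.

// Illustrative-only helper: functions directly referenced from `region`.
static llvm::SmallPtrSet<FuncOp, 4> GetDirectCallees(ModuleOp module,
                                                     Region& region) {
  llvm::SmallPtrSet<FuncOp, 4> callees;
  SymbolTable symbol_table(module);
  auto uses = symbol_table.getSymbolUses(&region);
  if (!uses) return callees;
  for (auto& use : *uses)
    if (auto func = symbol_table.lookup<FuncOp>(
            use.getSymbolRef().getRootReference()))
      callees.insert(func);
  return callees;
}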
+llvm::SmallPtrSet GetReachableFunctionsFromRegion(ModuleOp module, + Region& region) { + llvm::SmallPtrSet visited_functions; + + SymbolTable symbol_table(module); + auto symbol_uses = symbol_table.getSymbolUses(®ion); + if (!symbol_uses) return {}; + + for (auto& use : *symbol_uses) + if (auto func = + symbol_table.lookup(use.getSymbolRef().getRootReference())) + visited_functions.insert(func); + + llvm::SmallVector functions_to_visit(visited_functions.begin(), + visited_functions.end()); + while (!functions_to_visit.empty()) { + llvm::SmallVector new_functions_to_visit; + + for (FuncOp function_to_visit : functions_to_visit) { + auto func_symbol_uses = + symbol_table.getSymbolUses(function_to_visit.getCallableRegion()); + if (!func_symbol_uses) continue; + + for (auto& use : *func_symbol_uses) + if (auto func = symbol_table.lookup( + use.getSymbolRef().getRootReference())) + if (visited_functions.insert(func).second) + new_functions_to_visit.push_back(func); + } + + functions_to_visit.swap(new_functions_to_visit); + } + + return visited_functions; +} + +// Collects all functions and transitive functions reachable from region that +// contain replicate variant ops. +llvm::SmallDenseMap GetReachableFunctionsToClone( + ModuleOp module, Region& region, + const llvm::Optional& devices) { + llvm::SmallPtrSet reachable_functions = + GetReachableFunctionsFromRegion(module, region); + + llvm::SmallDenseMap functions_to_clone; + llvm::SmallVector functions_to_visit; + for (FuncOp func : reachable_functions) { + if (!func.getCallableRegion()) continue; + if (HasReplicaVariantOps(*func.getCallableRegion(), devices)) { + functions_to_clone.insert({func.getName(), func}); + functions_to_visit.push_back(func); + } + } + + while (!functions_to_visit.empty()) { + llvm::SmallVector new_functions_to_visit; + + for (FuncOp func_to_visit : functions_to_visit) { + auto func_uses = func_to_visit.getSymbolUses(module); + if (!func_uses) continue; + for (auto use : *func_uses) { + auto parent_func = use.getUser()->getParentOfType(); + if (!parent_func || !reachable_functions.contains(parent_func) || + !functions_to_clone.insert({parent_func.getName(), parent_func}) + .second) + continue; + new_functions_to_visit.push_back(parent_func); + } + } + + functions_to_visit.swap(new_functions_to_visit); + } + + return functions_to_clone; +} + +struct FuncOldNameAndClone { + StringRef old_name; + FuncOp clone; +}; + +// Replaces all symbol uses with cloned functions, for `region` and across the +// cloned functions themselves. +LogicalResult UpdateSymbolUsesWithClones( + SymbolTable& symbol_table, ModuleOp module, Region& region, + llvm::MutableArrayRef cloned_functions) { + llvm::SmallVector, 4> old_to_new_names; + old_to_new_names.reserve(cloned_functions.size()); + for (auto& cloned_function : cloned_functions) + old_to_new_names.push_back( + {cloned_function.old_name, cloned_function.clone.getName()}); + + for (const auto& old_to_new_name : old_to_new_names) { + if (failed(symbol_table.replaceAllSymbolUses( + old_to_new_name.first, old_to_new_name.second, ®ion))) + return failure(); + + for (auto& cloned_function : cloned_functions) + if (failed(symbol_table.replaceAllSymbolUses( + old_to_new_name.first, old_to_new_name.second, + cloned_function.clone.getCallableRegion()))) + return failure(); + } + return success(); +} + +// Collects TPU device ordinal for outside compilation communication ops. 
This +// currently assumes outside compilation only uses `TPU_REPLICATED_CORE_0` +// aliased device for the device computation. +llvm::Optional GetDeviceOrdinal( + const llvm::Optional& devices, Location loc, + unsigned replica_id) { + int64_t device_ordinal = 0; + if (devices.hasValue()) { + if (auto tpu_replica_0 = devices.getValue().get("TPU_REPLICATED_CORE_0")) { + llvm::StringRef tpu_device = tpu_replica_0.cast()[replica_id] + .cast() + .getValue(); + if (succeeded(tensorflow::GetDeviceOrdinalFromDeviceString( + loc, tpu_device, &device_ordinal))) { + return llvm::Optional(device_ordinal); + } + } + } + return llvm::None; +} + +// Updates replica variant ops in a region based on replica `replica_id`. +// TODO(b/157624749): Replace this with better abstraction to differentiate ops +// for different replicas. Some ops, such as XlaHostCompute op or TPU Embedding +// ops, require replica id to be added as an op attribute to be used during +// execution. Handle such ops separately and add an integer attribute that +// represents replica id. +LogicalResult UpdateRegionReplicateVariantOps( + OpBuilder& builder, Location loc, Region& region, int replica_id, + llvm::MutableArrayRef cloned_functions, + const llvm::Optional& devices) { + llvm::Optional device_ordinal = + GetDeviceOrdinal(devices, loc, replica_id); + + auto update_replicate_variant_ops = [&](Operation* op) { + // Add replica id. + if (RequiresReplicaIDAttribute(op)) + op->setAttr(kReplicaIdAttr, builder.getI32IntegerAttr(replica_id)); + + if (!devices.hasValue()) return; + + // Map aliased devices to explicit devices based on replica. + if (auto launch = dyn_cast(op)) + if (auto device_by_replica = devices.getValue().get(launch.device())) + launch.setAttr( + kDeviceAttr, + device_by_replica.cast()[replica_id].cast()); + + // Add device ordinal. + if (device_ordinal && RequiresDeviceOrdinalAttribute(op)) + op->setAttr(kDeviceOrdinalAttr, + builder.getI64IntegerAttr(*device_ordinal)); + }; + + region.walk(update_replicate_variant_ops); + for (auto& cloned_function : cloned_functions) + cloned_function.clone.getCallableRegion()->walk( + update_replicate_variant_ops); + + return success(); } // Creates islands per replica from `tf_device.replicate` region. If for a // `tf_device.launch` op the device is an aliased device of the // `tf_device.replicate`, the device will be remapped to an explicit device // for the associated replica island. -llvm::SmallVector ExpandReplicateIntoReplicas( - const Dialect* tf_dialect, OpBuilder* builder, +LogicalResult ExpandReplicateIntoReplicas( + const Dialect* tf_dialect, OpBuilder& builder, ModuleOp module, tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op, - int num_replicas) { - auto devices = replicate_op.devices(); - const bool has_devices = devices.hasValue(); - llvm::SmallVector replicas; + int num_replicas, llvm::SmallVectorImpl& replicas) { replicas.reserve(num_replicas); + auto devices = replicate_op.devices(); // Collect result types and operands. Operation& terminator = replicate_op.GetBody().back(); @@ -88,16 +270,30 @@ llvm::SmallVector ExpandReplicateIntoReplicas( llvm::SmallVector replica_inputs(island_op.controlInputs()); // Replace replicate terminator with YieldOp. 
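UpdateRegionReplicateVariantOps above folds three per-replica rewrites into one walk: tagging ops that need `_xla_replica_id`, remapping aliased launch devices, and attaching the TPU device ordinal. The fragment below is a hedged, reduced sketch of only the first rewrite, factored out purely for illustration; kReplicaIdAttr and RequiresReplicaIDAttribute come from this file, and the helper name is not part of the change.

// Reduced illustration of the replica-id tagging performed above.
static void TagReplicaVariantOps(OpBuilder& builder, Region& region,
                                 int replica_id) {
  region.walk([&](Operation* op) {
    if (RequiresReplicaIDAttribute(op))
      op->setAttr(kReplicaIdAttr, builder.getI32IntegerAttr(replica_id));
  });
}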
- builder->setInsertionPoint(&terminator); - builder->create(terminator.getLoc(), - terminator.getOperands()); + builder.setInsertionPoint(&terminator); + builder.create(terminator.getLoc(), + terminator.getOperands()); terminator.erase(); - builder->setInsertionPoint(island_op); + auto funcs_to_clone = + GetReachableFunctionsToClone(module, replicate_op.body(), devices); + SymbolTable symbol_table(module); + + builder.setInsertionPoint(island_op); BlockAndValueMapping mapping; for (int i : llvm::seq(0, num_replicas)) { + // Clone reachable functions with replica variant ops. + llvm::SmallVector cloned_functions; + cloned_functions.reserve(funcs_to_clone.size()); + for (auto& func_to_clone : funcs_to_clone) { + auto cloned_function = func_to_clone.getSecond().clone(); + symbol_table.insert(cloned_function, module.end()); + cloned_functions.push_back( + {func_to_clone.getSecond().getName(), cloned_function}); + } + // Create new island for replica. - auto replica = builder->create( + auto replica = builder.create( island_op.getLoc(), output_types, control_type, replica_inputs); // Map block arg to replica arg. @@ -109,28 +305,19 @@ llvm::SmallVector ExpandReplicateIntoReplicas( // Copy over replicate region into replica island. replicate_op.body().cloneInto(&replica.body(), mapping); - // TODO(b/157624749): Replace this with better abstraction to - // differentiate ops for different replicas. - // Some ops, such as XlaHostCompute op or TPU Embedding ops, require - // replica id to be added as an op attribute to be used during - // execution. Handle such ops separately and add an integer attribute - // that represents replica id. - AddReplicaIdToOpsInReplicatedRegion(builder, &replica.body(), i); + if (failed(UpdateSymbolUsesWithClones(symbol_table, module, replica.body(), + cloned_functions))) + return failure(); - // Map aliased devices to explicit devices based on replica. - if (has_devices) { - replica.walk([&](tf_device::LaunchOp launch) { - if (auto device_by_replica = devices.getValue().get(launch.device())) - launch.setAttr( - kDeviceAttr, - device_by_replica.cast()[i].cast()); - }); - } + if (failed(UpdateRegionReplicateVariantOps( + builder, replicate_op.getLoc(), replica.body(), + /*replica_id=*/i, cloned_functions, devices))) + return failure(); replicas.push_back(replica); } - return replicas; + return success(); } // Creates islands per replica from `tf_device.replicate` region and remap @@ -183,17 +370,19 @@ llvm::SmallVector ExpandReplicateIntoReplicas( // }) {device = "/DEVICE:3"} : () -> tensor // tf_executor.yield %a1, %b1 : tensor, tensor // } -void CreateIslandsFromReplicate(const Dialect* tf_dialect, - tf_executor::GraphOp graph_op, - tf_executor::IslandOp island_op, - tf_device::ReplicateOp replicate_op) { +LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect, + ModuleOp module, + tf_executor::GraphOp graph_op, + tf_executor::IslandOp island_op, + tf_device::ReplicateOp replicate_op) { OpBuilder builder(island_op); - const int num_replicas = replicate_op.n().getLimitedValue(); + const int num_replicas = replicate_op.n(); // Create islands per replica. - llvm::SmallVector replicas = - ExpandReplicateIntoReplicas(tf_dialect, &builder, island_op, replicate_op, - num_replicas); + llvm::SmallVector replicas; + if (failed(ExpandReplicateIntoReplicas(tf_dialect, builder, module, island_op, + replicate_op, num_replicas, replicas))) + return failure(); // Collect all replica results. 
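The per-replica loop above clones every replica-variant callee once per replica and relies on the module's SymbolTable to keep the clones uniquely named before UpdateSymbolUsesWithClones redirects references to them. A hedged sketch of that single step, separated only for illustration (the helper name is assumed):

// Clones `func` for one replica and registers the clone at the end of the
// module; SymbolTable::insert renames the clone if its name would collide.
static FuncOp CloneCalleeForReplica(SymbolTable& symbol_table, ModuleOp module,
                                    FuncOp func) {
  FuncOp clone = func.clone();
  symbol_table.insert(clone, module.end());
  return clone;
}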
llvm::SmallVector replicas_outputs(replicate_op.getNumResults(), @@ -244,36 +433,41 @@ void CreateIslandsFromReplicate(const Dialect* tf_dialect, } island_op.erase(); + return success(); } -// Finds islands with a single `tf_device.replicate` and create individual -// islands per replica of the replicate. -void LowerSingleIslandReplicateToIslands(const Dialect* tf_dialect, - tf_executor::GraphOp graph_op, - tf_executor::IslandOp island_op) { - if (!island_op.WrapsSingleOp()) return; - - if (auto replicate_op = - llvm::dyn_cast(&island_op.GetBody().front())) - CreateIslandsFromReplicate(tf_dialect, graph_op, island_op, replicate_op); -} - -void ReplicateToIslandPass::runOnFunction() { - const Dialect* tf_dialect = getContext().getRegisteredDialect("tf"); +void ReplicateToIslandPass::runOnOperation() { + auto module = getOperation(); + const Dialect* tf_dialect = getContext().getLoadedDialect("tf"); if (!tf_dialect) { - signalPassFailure(); - getFunction().emitError() << "'tf' dialect is not registered"; + module.emitError() << "'tf' dialect is not registered"; + return signalPassFailure(); } - getFunction().walk([&](tf_executor::GraphOp graph_op) { - for (auto island_op : - llvm::make_early_inc_range(graph_op.getOps())) - LowerSingleIslandReplicateToIslands(tf_dialect, graph_op, island_op); + // Find islands with a single `tf_device.replicate` and create individual + // islands per replica of the replicate. + llvm::SmallVector replicate_op_islands; + module.walk([&](tf_executor::GraphOp graph_op) { + for (auto island_op : graph_op.getOps()) { + if (!island_op.WrapsSingleOp()) continue; + + if (isa(&island_op.GetBody().front())) + replicate_op_islands.push_back(island_op); + } }); + + for (tf_executor::IslandOp island_op : replicate_op_islands) { + auto graph_op = island_op.getParentOfType(); + auto replicate_op = + cast(island_op.GetBody().front()); + if (failed(CreateIslandsFromReplicate(tf_dialect, module, graph_op, + island_op, replicate_op))) + return signalPassFailure(); + } } } // anonymous namespace -std::unique_ptr> CreateReplicateToIslandPass() { +std::unique_ptr> CreateReplicateToIslandPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc index 7e8e9ee30c8..648805febfe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc @@ -26,10 +26,13 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project @@ -39,6 +42,9 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h" + +#define DEBUG_TYPE "tf-resource-device-inference" namespace mlir { namespace TF { @@ -66,22 +72,18 @@ class PerFunctionResult { : alias_analysis_(alias_analysis) {} // Returns the recorded device assignment for a resource, if any. - llvm::Optional DeviceForResource( - const Value resource) const { - llvm::Optional result; - if (alias_analysis_.IsUnknownResource(resource)) return result; + Optional DeviceForResource(Value resource) const { + Optional result; + if (alias_analysis_.IsUnknownResource(resource)) return llvm::None; for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) { auto it = resource_id_to_device_.find(id); if (it == resource_id_to_device_.end()) continue; - if (!result) { + if (!result || result == it->second) { result = it->getSecond(); continue; } - if (result != it->getSecond()) { - // Got conflicting assignments, clear the result. - result.reset(); - return result; - } + // Got conflicting assignments + return llvm::None; } return result; } @@ -90,7 +92,7 @@ class PerFunctionResult { // conflicts with an existing one, returns an error. // // If `changed` is provided, assign *changed to true if anything is modified. - LogicalResult AddResourceDevice(const Value resource, llvm::StringRef device, + LogicalResult AddResourceDevice(Value resource, StringRef device, bool* changed = nullptr) { if (alias_analysis_.IsUnknownResource(resource)) return success(); for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) { @@ -106,13 +108,12 @@ class PerFunctionResult { } private: - llvm::SmallDenseMap resource_id_to_device_; + llvm::SmallDenseMap resource_id_to_device_; const TF::ResourceAliasAnalysis::Info& alias_analysis_; }; // Tries to record device assignment for a resource. -LogicalResult AddResourceDeviceAndEmitError(const Value resource, - llvm::StringRef device, +LogicalResult AddResourceDeviceAndEmitError(Value resource, StringRef device, Operation* error_reporting_op, PerFunctionResult* result, bool* changed = nullptr) { @@ -124,18 +125,34 @@ LogicalResult AddResourceDeviceAndEmitError(const Value resource, return res; } +// Extracts and canonicalizes the device attribute. +inline StringRef GetDeviceAttr(FuncOp func, int arg_no) { + auto device_attr = + func.getArgAttrOfType(arg_no, kFuncDeviceAttr); + return device_attr ? device_attr.getValue() : ""; +} + +// Extracts and canonicalizes the device attribute. +inline StringRef GetDeviceAttr(Operation* op) { + auto device_attr = op->getAttrOfType(kDeviceAttr); + return device_attr ? device_attr.getValue() : ""; +} + +// Print operation with debug info (to get line number info for debugging) +void dump(StringRef message, Operation* op) { + llvm::dbgs() << message; + op->print(llvm::dbgs(), OpPrintingFlags().enableDebugInfo(true)); + llvm::dbgs() << "\n"; +} + // Propagates device assignment inside a function. LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op, PerFunctionResult* result) { OpBuilder builder(func_op); // Function arguments. 
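With GetDeviceAttr and filter_resources in place, the loops in this function (and several later ones in the pass) can visit only resource-typed values and read an already-canonicalized device string. Below is a minimal sketch of that idiom with an assumed helper name; the real pass records the result through AddResourceDeviceAndEmitError rather than discarding it.

// Illustrative-only: iterate resource arguments and fetch their device.
static void VisitResourceArgs(FuncOp func_op) {
  for (BlockArgument arg : filter_resources(func_op.getArguments())) {
    StringRef device = GetDeviceAttr(func_op, arg.getArgNumber());
    (void)device;  // the pass would record or propagate this assignment
  }
}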
- for (auto arg : func_op.getArguments()) { - if (!mlir::getElementTypeOrSelf(arg.getType()).isa()) { - continue; - } - auto device_attr = func_op.getArgAttrOfType( - arg.getArgNumber(), kFuncDeviceAttr); - if (!device_attr || device_attr.getValue() == "") { + for (auto arg : filter_resources(func_op.getArguments())) { + StringRef device_attr = GetDeviceAttr(func_op, arg.getArgNumber()); + if (device_attr.empty()) { // If device_attr does not exist, try to construct it from any recorded // assignment. if (auto device = result->DeviceForResource(arg)) { @@ -145,51 +162,71 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op, continue; } // Record the attribute. - auto res = AddResourceDeviceAndEmitError(arg, device_attr.getValue(), - func_op, result); + auto res = AddResourceDeviceAndEmitError(arg, device_attr, func_op, result); if (failed(res)) return res; } - auto walk_res = func_op.walk([&](Operation* op) { - if (auto var_handle = llvm::dyn_cast(op)) { - // Record VarHandleOp's device attribute. - auto device_attr = - var_handle.getAttrOfType(kDeviceAttr); - if (!device_attr || device_attr.getValue().empty()) { - return WalkResult::advance(); - } - auto res = AddResourceDeviceAndEmitError( - var_handle.resource(), device_attr.getValue(), op, result); - if (failed(res)) return WalkResult::interrupt(); - } - if (auto identity = llvm::dyn_cast(op)) { - // Try to construct IdentityOp's attribute from recorded assignment. - if (!mlir::getElementTypeOrSelf(identity.output().getType()) - .isa()) { - return WalkResult::advance(); - } - if (auto device = result->DeviceForResource(identity.output())) { - auto device_attr = - identity.getAttrOfType(kDeviceAttr); - if (!device_attr || device_attr.getValue().empty()) { - identity.setAttr(kDeviceAttr, builder.getStringAttr(*device)); + + // To support WhileRegion, we need to propagate device attributes from + // WhileRegion operands to body/cond region arguments *prior* to visiting + // these regions. Use tensorflow::walk() instead of MLIR core walker to + // implement such a pre-order walk. + auto walk_res = tensorflow::GenericWalk( + func_op, [&](Operation* op, const tensorflow::WalkStage& stage) { + // We just need to visit operations in pre-order mode. + if (!stage.IsBeforeAllRegions()) return WalkResult::advance(); + + if (auto var_handle = dyn_cast(op)) { + // Record VarHandleOp's device attribute. + StringRef device_attr = GetDeviceAttr(op); + if (device_attr.empty()) return WalkResult::advance(); + auto res = AddResourceDeviceAndEmitError(var_handle.resource(), + device_attr, op, result); + if (failed(res)) return WalkResult::interrupt(); + } else if (auto identity = dyn_cast(op)) { + LLVM_DEBUG(dump("Visiting ", identity)); + // Try to construct IdentityOp's attribute from recorded assignment. + if (!GetDeviceAttr(op).empty()) return WalkResult::advance(); + for (auto output : filter_resources(op->getResults())) { + LLVM_DEBUG(llvm::dbgs() << " Processing output #" + << output.getResultNumber() << "\n"); + if (auto device = result->DeviceForResource(output)) { + LLVM_DEBUG(llvm::dbgs() + << " Setting device = " << *device << "\n"); + identity.setAttr(kDeviceAttr, builder.getStringAttr(*device)); + } + } + } else if (auto while_region = dyn_cast(op)) { + // For WhileRegion, do local analysis prior to visiting the attached + // regions and propagate device annotations to the cond and body + // region arguments. The annotations are the union of annotations + // on the input and result. 
Resource alias analysis already propagates + // resource ID from the inputs to the results for a while, so just + // need to consider the results. + LLVM_DEBUG(llvm::dbgs() << "Visiting WhileRegion\n"); + + for (auto output : filter_resources(while_region.getResults())) { + auto device = result->DeviceForResource(output); + int output_index = output.getResultNumber(); + if (!device) { + LLVM_DEBUG(llvm::dbgs() + << " No device for output #" << output_index << "\n"); + continue; + } + // Transfer the annotation to both region arguments + for (Region* region : while_region.getRegions()) { + BlockArgument arg = region->getArgument(output_index); + LLVM_DEBUG(llvm::dbgs() + << " Propagating device = '" << *device + << "' to arg #" << output_index << " of region #" + << region->getRegionNumber() << "\n"); + if (failed(AddResourceDeviceAndEmitError(arg, *device, + while_region, result))) + return WalkResult::interrupt(); + } + } } - } - return WalkResult::advance(); - } - // Propagate and record output device assignment for other ops based on - // existing recording. E.g., IdentityN. - for (auto output : op->getResults()) { - if (!mlir::getElementTypeOrSelf(output.getType()) - .isa()) { - continue; - } - if (auto device = result->DeviceForResource(output)) { - auto res = AddResourceDeviceAndEmitError(output, *device, op, result); - if (failed(res)) return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }); + return WalkResult::advance(); + }); return failure(walk_res.wasInterrupted()); } @@ -198,13 +235,13 @@ void ResourceDeviceInference::runOnOperation() { const auto& resource_alias_analysis = getAnalysis(); - llvm::SmallDenseMap per_function_results; + llvm::SmallDenseMap per_function_results; llvm::SetVector worklist; - module.walk([&](FuncOp func_op) { + for (auto func_op : module.getOps()) { worklist.insert(func_op); per_function_results.try_emplace( func_op, func_op, resource_alias_analysis.GetAnalysisForFunc(func_op)); - }); + } // Helper that propagates an op's recorded operand device assignments to its // called function's arguments. auto propagate_operands_to_callee_arguments = @@ -214,51 +251,59 @@ void ResourceDeviceInference::runOnOperation() { assert(callee); auto& callee_res = per_function_results.find(callee)->getSecond(); bool callee_needs_recompute = false; - for (auto operand_and_argument : - llvm::zip(caller_operands, callee.getArguments())) { - if (!mlir::getElementTypeOrSelf( - std::get<0>(operand_and_argument).getType()) - .isa()) { - continue; - } - auto device = - caller_res.DeviceForResource(std::get<0>(operand_and_argument)); + for (BlockArgument arg : filter_resources(callee.getArguments())) { + Value arg_operand = caller_operands[arg.getArgNumber()]; + auto device = caller_res.DeviceForResource(arg_operand); if (!device) continue; - if (failed(AddResourceDeviceAndEmitError( - std::get<1>(operand_and_argument), *device, caller, - &callee_res, &callee_needs_recompute))) { + LLVM_DEBUG(llvm::dbgs() + << "Propagating '" << *device << "' to arg #" + << arg.getArgNumber() << " of function @" + << callee.getName() << "\n"); + if (failed(AddResourceDeviceAndEmitError(arg, *device, caller, + &callee_res, + &callee_needs_recompute))) return failure(); - } } // If the callee recording is modified, make sure that it will be // reprocessed. 
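Zooming out, the pass is a standard fixed-point computation: the lambda above pushes device information from a caller's operands into a callee's per-function result, and any callee whose result table changes is re-queued. The skeleton below shows only that driver shape with the bodies elided; it is a sketch, not the pass itself.

// Worklist skeleton of the propagation driver (illustration only).
static void PropagateUntilFixedPoint(ModuleOp module) {
  llvm::SetVector<FuncOp> worklist;
  for (FuncOp func : module.getOps<FuncOp>()) worklist.insert(func);
  while (!worklist.empty()) {
    FuncOp func = worklist.pop_back_val();
    (void)func;
    // 1) propagate devices within `func`;
    // 2) propagate operand devices to the callees of tf.If/tf.While/call ops,
    //    re-inserting any callee whose recorded assignments changed.
  }
}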
- if (callee_needs_recompute) { - worklist.insert(callee); - } + if (callee_needs_recompute) worklist.insert(callee); } return success(); }; while (!worklist.empty()) { - auto func_op = worklist.back(); - worklist.pop_back(); + auto func_op = worklist.pop_back_val(); auto& func_res = per_function_results.find(func_op)->getSecond(); // In-function propagation. - if (failed(ComputeResourceDevicesInComputation(func_op, &func_res))) { + if (failed(ComputeResourceDevicesInComputation(func_op, &func_res))) return signalPassFailure(); - } + // Propagation to callees. auto walk_res = func_op.walk([&](Operation* op) { - if (auto while_op = llvm::dyn_cast(op)) { + if (auto while_op = dyn_cast(op)) { if (failed(propagate_operands_to_callee_arguments( while_op, while_op.getOperands(), - {while_op.body_func(), while_op.cond_func()}, func_res))) - return WalkResult::interrupt(); - } else if (auto if_op = llvm::dyn_cast(op)) { - if (failed(propagate_operands_to_callee_arguments( - if_op, if_op.input(), {if_op.then_func(), if_op.else_func()}, + {while_op.body_function(), while_op.cond_function()}, func_res))) return WalkResult::interrupt(); + } else if (auto if_op = dyn_cast(op)) { + if (failed(propagate_operands_to_callee_arguments( + if_op, if_op.input(), + {if_op.then_function(), if_op.else_function()}, func_res))) + return WalkResult::interrupt(); + } else if (auto call = dyn_cast(op)) { + auto func = dyn_cast(call.resolveCallable()); + if (!func) { + op->emitError( + "Cannot propagate device attribute to callee: Unable to resolve " + "call"); + return WalkResult::interrupt(); + } + LLVM_DEBUG(llvm::dbgs() + << "Visiting call to function @" << func.getName() << "\n"); + if (failed(propagate_operands_to_callee_arguments( + call, call.getArgOperands(), {func}, func_res))) + return WalkResult::interrupt(); } return WalkResult::advance(); }); @@ -266,15 +311,15 @@ void ResourceDeviceInference::runOnOperation() { } } +PassRegistration pass( + "tf-resource-device-inference", + "Propagates the device attribute on resources from callers to callees."); + } // namespace std::unique_ptr> CreateResourceDeviceInferencePass() { return std::make_unique(); } -static PassRegistration pass( - "tf-resource-device-inference", - "Propagates the device attribute on resources from callers to callees."); - } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 702455d156d..5984aafb88f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -15,12 +15,17 @@ limitations under the License. // This pass lifts resource variable operations outside of device computation. +#include #include +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -31,20 +36,24 @@ limitations under the License. 
#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" @@ -136,15 +145,18 @@ struct ResourceOpLiftingPass void runOnOperation() override; }; -// Removes identity nodes in the block. The device computation does not need -// such nodes to carry information. -void RemoveIdentity(Block* block) { - for (auto& op : llvm::make_early_inc_range(*block)) { - if (isa(&op)) { - op.replaceAllUsesWith(op.getOperands()); - op.erase(); - } - } +bool IsResource(Value value) { + return getElementTypeOrSelf(value.getType()).isa(); +} + +// Get the type of the data contained in a resource. Returns null if there is +// no single type in the resource. +Type GetResourceSubtype(Value value) { + auto resource_type = + getElementTypeOrSelf(value.getType()).dyn_cast(); + auto subtypes = resource_type.getSubtypes(); + if (subtypes.size() == 1) return subtypes[0]; + return nullptr; } // Performs store-load forwarding. This effectively removes @@ -186,166 +198,448 @@ void ForwardStoreToLoad(Block* block) { } } -// Moves resource load operations with the provided `move_load` function. This -// assumes load-store forwarding has been performed on this block such that -// all loads of same resource are on its initial values. A `skip_load` functions -// is used to indicate whether a load should be skipped. If there are multiple -// loads on the same resource, only the first one will be moved, and the later -// ones will be removed and replaced with the first one. -void HoistResourceLoads( - Block* block, llvm::function_ref skip_load, - llvm::function_ref move_load) { - llvm::SmallDenseMap resource_to_read_ops; +//===----------------------------------------------------------------------===// +// RegionResourceHoister +//===----------------------------------------------------------------------===// +// Helper class to hoist resource ops out of regions attached to an op. +class RegionResourceHoister { + public: + explicit RegionResourceHoister(Operation* op) : op_(op) {} + + // Analyzes attached regions to record resources read and written. + LogicalResult Analyze(); + + // Returns all resources accessed by the regions attached the op. 
+ auto& GetResources() { return resources_; } + + // Returns if the given value is a resouce that needs lifting. + bool Contains(Value resource) const { + return resources_.find(resource) != resources_.end(); + } + + // Drops the given resource from lifting. + void DropResource(Value resource) { + resources_.erase(resource); + written_resources_.remove(resource); + } + + // Replaces all resource loads in all regions attached to the op. + void ReplaceResourceLoads(bool read_only) { + llvm::for_each(op_->getRegions(), [&](Region& region) { + ReplaceResourceLoads(region, read_only); + }); + } + + static LogicalResult ReplaceOpWithNewOp(Operation* op); + + private: + // Returns if any resources need lifting. + bool NeedsLifting() const { return !resources_.empty(); } + + // Returns the number of results generated by the lifted op. + int GetLiftedNumResults() const { return num_new_results_; } + + // Generates hoisted reads for resources that need them before the op. + void GenerateHoistedReads(); + + // Replaces all resource loads in the given region with hoisted loads. If + // `read_only` is true, limit this replacement to read only resources. + void ReplaceResourceLoads(Region& region, bool read_only); + + // Appends final values writte to resources to the region returns for the + // given set of regions. + void AppendResourceStoreValueToReturn(RegionRange regions); + + // Performs the final replacement of the op. + void ReplaceOpWithNewOp(); + + // Returns is this resource was written to in any of the regions. + bool IsWritten(Value resource) const { + return written_resources_.contains(resource); + } + + static LogicalResult HoistResourcesOutOfIfCaseCluster(Operation* op); + static LogicalResult HoistResourcesOutOfWhileRegion(TF::WhileRegionOp op); + + Operation* op_; + + // Per resource information about accesses to that resource. + struct ResourceInfo { + // Is this resource read in any of the regions? + bool is_read; + // Is this resource written in any of the regions? + bool is_written; + // Is this resource written in all of the regions? + bool is_written_all; + // The hoisted read used to replace region reads. + Value hoisted_read; + // the type of the data held by the resource. + Type data_type; + // For written resources, the result # of the lifted op which will hold the + // value of the resource. This result will be used to generates writes to + // the resource after the lifted op. + int result_index; + // Attributes on the read operation. + DictionaryAttr read_attrs; + // Attributes on the write operation. + DictionaryAttr write_attrs; + + ResourceInfo() + : is_read(false), + is_written(false), + is_written_all(false), + hoisted_read(nullptr), + data_type(nullptr), + result_index(-1) {} + + bool IsResultIndexAssigned() { return result_index != -1; } + + // Refine the resource type using the given type `type`. + void RefineType(Type type) { + if (!data_type) { + data_type = type; + } else { + data_type = TF::GetCastCompatibleType(data_type, type, + /*may_ignore_ref_type_a=*/false); + assert(data_type != nullptr && "Resource used with incompatible types"); + } + } + }; + llvm::MapVector resources_; + llvm::SetVector written_resources_; + // number of new results after lifting. + int num_new_results_; +}; + +// Analyzes resources that are read or written within attached regions. +LogicalResult RegionResourceHoister::Analyze() { + // Hoisting of child regions might have created opportunity for store-load + // forwarding. 
+ for (Region& region : op_->getRegions()) { + ForwardStoreToLoad(®ion.front()); + } + + llvm::SetVector all_resources; + bool is_func = false; + // For functions, the resources to analyze are the function arguments. + // Otherwise, its the region captures. + if (FuncOp func = dyn_cast(op_)) { + is_func = true; + Region& body = func.getBody(); + for (BlockArgument arg : body.getArguments()) { + if (IsResource(arg)) all_resources.insert(arg); + } + } else { + getUsedValuesDefinedAbove(op_->getRegions(), all_resources); + all_resources.remove_if([](Value value) { return !IsResource(value); }); + } + + num_new_results_ = op_->getNumResults(); + + for (auto resource : all_resources) { + ResourceInfo info; + info.data_type = GetResourceSubtype(resource); + llvm::BitVector written_regions(op_->getNumRegions()); + bool unsupported_use = false; + for (OpOperand& use : resource.getUses()) { + Operation* user = use.getOwner(); + // If the user is not in one of the regions, we are not interested in it. + // Since all the sub-regions within this region (i.e., regions attached to + // op's in this region) have themselves gone through lifting, all resource + // users are expected to be operations in this region and and not embedded + // within other sub-regions attached to op's in this region. So the check + // for whether a user is in one of the regions attached to this op is + // straightforward. + if (user->getParentRegion()->getParentOp() != op_) continue; + + // For functions, if the resource is used as a return operand, use that + // as its result index. + if (is_func && isa(user)) { + assert(!info.IsResultIndexAssigned() && + "Expect resource argument to returned no more than once"); + info.result_index = use.getOperandNumber(); + continue; + } + + auto read = dyn_cast(user); + auto write = dyn_cast(user); + if (!read && !write) { + unsupported_use = true; + break; + } + + if (read && !info.is_read) { + info.is_read = true; + info.RefineType(read.value().getType()); + info.read_attrs = user->getAttrDictionary(); + } + + if (write) { + info.is_written = true; + info.RefineType(write.value().getType()); + info.write_attrs = user->getAttrDictionary(); + written_regions.set(user->getParentRegion()->getRegionNumber()); + } + } + + // If the resource is used in an op that we do not understand, skip + // lifting for that resource. + if (unsupported_use) continue; + + info.is_written_all = written_regions.count() == op_->getNumRegions(); + + // If the resource is written in some but not all regions, we would need + // a read for the value before these regions. Note that this is applicable + // only to multi-region ops: + // If/Case: If not all regions write to the resource, post hoisting the read + // value need to be routed through all paths that don't write. + // While: since while condition cannot write, any resource written in the + // while body will need to be read as well in case the while body is never + // executed. + // Both cases are handled by the condition below. + if (info.is_written && !info.is_written_all) info.is_read = true; + + // Allocate a result index for written resources that don't have one. + if (info.is_written) { + written_resources_.insert(resource); + if (!info.IsResultIndexAssigned()) info.result_index = num_new_results_++; + } + + resources_.insert({resource, info}); + } + return success(); +} + +// Generates hoisted reads for all resources that need them just before the op. 
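Before the hoisted reads are generated below, Analyze above makes one non-obvious decision: a resource written in some but not all regions is also marked as read, so the non-writing paths (the other If/Case branch, or a while body that never executes) can forward the pre-op value. A plain-C++ model of just that rule, with an assumed helper name:

// Model of the "partially written implies read" rule used by Analyze().
static bool NeedsPreOpRead(bool is_read, unsigned regions_writing,
                           unsigned num_regions) {
  const bool is_written = regions_writing > 0;
  const bool written_in_all = regions_writing == num_regions;
  return is_read || (is_written && !written_in_all);
}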
+void RegionResourceHoister::GenerateHoistedReads() { + OpBuilder builder(op_); + for (auto& resource_it : GetResources()) { + Value resource = resource_it.first; + auto& info = resource_it.second; + + if (info.is_read) { + Operation* read = builder.create( + op_->getLoc(), info.data_type, resource); + read->setAttrs(info.read_attrs); + info.hoisted_read = read->getResult(0); + } + } +} + +// Replaces all resource reads with the hoisted read. +void RegionResourceHoister::ReplaceResourceLoads(Region& region, + bool read_only) { + assert(llvm::hasSingleElement(region) && "Expected single block region"); // Only iterate through ops directly in the body as we can't handle // ops nested deeper in regions. - for (Operation& op : llvm::make_early_inc_range(*block)) { - auto read_variable_op = dyn_cast(&op); - if (!read_variable_op) continue; - if (skip_load(read_variable_op)) continue; + auto all_reads = region.front().getOps(); + for (auto read_op : llvm::make_early_inc_range(all_reads)) { + Value resource = read_op.resource(); + if (!Contains(resource)) continue; - Value resource = read_variable_op.resource(); - auto p = resource_to_read_ops.insert({resource, read_variable_op}); - if (p.second) { - move_load(read_variable_op); - continue; + ResourceInfo& info = resources_[resource]; + // If replacing loads for read only resources, skip if the resource + // was written to. + if (read_only && info.is_written) continue; + + read_op.replaceAllUsesWith(info.hoisted_read); + read_op.erase(); + } +} + +// For written resources, add its value at the end of each region to that +// regions return value. For a region, its value at the end may be a value +// written to that resource in that region, or its hoisted read value if the +// resource is not written in that region. The return value can be vended out +// either as an existing return value, or a newly allocated return value. +void RegionResourceHoister::AppendResourceStoreValueToReturn( + RegionRange regions) { + for (Region* region : regions) { + assert(llvm::hasSingleElement(*region) && "Expected single block region"); + Block& front = region->front(); + auto old_return = front.getTerminator(); + assert(old_return->getNumOperands() == op_->getNumResults()); + auto new_return_operands = llvm::to_vector<4>(old_return->getOperands()); + new_return_operands.resize(num_new_results_); + + // initialize return values for written resources to be the hosited reads. + for (Value resource : written_resources_) { + const ResourceInfo& info = resources_[resource]; + new_return_operands[info.result_index] = info.hoisted_read; } - // Getting here means a load operation of this resource has been hoisted out - // before. Use hoisted load result to replace all uses of current op result - // and erase op. - op.replaceAllUsesWith(p.first->second); - op.erase(); - } -} + // Only iterate through ops directly in the body as op's embedded in child + // regions should have been lifted out. + auto assign_ops = front.getOps(); + for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) { + Value resource = assign_variable_op.resource(); + if (!IsWritten(resource)) continue; -// If there are any stores to resource defined outside of the block then the -// stored values must be returned so that new values can be used by sunk -// resource stores. -// Returns true if any resource variable stored values are appended, otherwise -// false. 
-bool AppendResourceStoreValueToReturn(Block* body) { - bool has_resource_store = false; - auto old_return = body->getTerminator(); - - llvm::SmallVector new_return_operands(old_return->getOperands()); - - // Only iterate through ops directly in the body as we can't handle ops nested - // deeper in regions. - for (auto assign_variable_op : body->getOps()) { - Value resource = assign_variable_op.resource(); - if (!resource) continue; - - // Skip resources created inside of the body. - if (resource.getParentRegion() == body->getParent()) continue; - - // TODO(ycao): Prevent same value from being returned multiple times. - // TODO(ycao): Do not return resource store value if it is defined outside - // of cluster. - new_return_operands.push_back(assign_variable_op.value()); - has_resource_store = true; - } - - // If no resource stores are found, no need to update return op. - if (!has_resource_store) return false; - - OpBuilder builder(old_return); - builder.create(old_return->getLoc(), - new_return_operands); - old_return->erase(); - return true; -} - -// Moves resource store operations to after cluster. This assumes load-store -// forwarding has been performed on this cluster such that there is at most one -// resource store operation carrying its final value. -tf_device::ClusterOp SinkResourceStores(tf_device::ClusterOp cluster, - OpBuilder* builder) { - // Update ReturnOp inside cluster's body to output final values of updated - // external resources. - if (!AppendResourceStoreValueToReturn(&cluster.GetBody())) return cluster; - - auto new_return_op = cluster.GetBody().getTerminator(); - llvm::SmallVector new_return_types(new_return_op->getOperandTypes()); - - builder->setInsertionPoint(cluster); - auto new_cluster = builder->create( - cluster.getLoc(), new_return_types, - /*operands=*/llvm::SmallVector(), cluster.getAttrs()); - new_cluster.body().takeBody(cluster.body()); - - // Replace uses of old cluster results with those of new_cluster. - for (auto result : llvm::zip(cluster.getResults(), new_cluster.getResults())) - std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); - - // Create a mapping from operands of new_return_op operands to new_cluster - // results. - BlockAndValueMapping mapper; - for (auto operand_result : - llvm::zip(new_return_op->getOperands(), new_cluster.getResults())) - mapper.map(std::get<0>(operand_result), std::get<1>(operand_result)); - - // Clone all resource store ops and map their operands to values returned from - // new_cluster. - for (Operation& op : llvm::make_early_inc_range(new_cluster.GetBody())) { - if (isa(op)) { - builder->clone(op, mapper); - op.erase(); + // TODO(ycao): Prevent same value from being returned multiple times. + // TODO(ycao): Do not return resource store value if it is defined outside + // of cluster. Both of these can be post-resource-op-lifting cleanup + // passes. + int result_index = resources_[resource].result_index; + new_return_operands[result_index] = assign_variable_op.value(); + assign_variable_op.erase(); } + old_return->setOperands(new_return_operands); } - - cluster.erase(); - return new_cluster; } -// Hoists resource variable loads and sinks stores from cluster. -LogicalResult HoistResourceOpsFromCluster(tf_device::ClusterOp cluster, - ModuleOp module) { - OpBuilder builder(module); +// Replace the old op with a new op (with potentially additional results), and +// add stores to written resources after the new op. 
+void RegionResourceHoister::ReplaceOpWithNewOp() { + auto new_result_types = llvm::to_vector<4>(op_->getResultTypes()); + int result_region = isa(op_) ? 1 : 0; + Operation* terminator = op_->getRegion(result_region).front().getTerminator(); + auto extra_result_types = + terminator->getOperands().drop_front(op_->getNumResults()).getTypes(); + new_result_types.insert(new_result_types.end(), extra_result_types.begin(), + extra_result_types.end()); + OpBuilder builder(op_); + // Clone ths old operation but with new result types. + Operation* new_op = Operation::create( + op_->getLoc(), op_->getName(), new_result_types, + llvm::to_vector<4>(op_->getOperands()), op_->getAttrs(), + llvm::to_vector<4>(op_->getSuccessors()), op_->getNumRegions()); + builder.insert(new_op); - // Remove identity nodes to avoid aliasing. - RemoveIdentity(&cluster.GetBody()); - - // Perform store-load forwarding. So that each resource is only loaded with - // its initial value and is only stored with its final value. - ForwardStoreToLoad(&cluster.GetBody()); - - // Move loads of external resources, if any, to before cluster. - // (Skipping resources created inside of cluster.) - HoistResourceLoads( - &cluster.GetBody(), - /*skip_load=*/ - [&](TF::ReadVariableOp read) { - return read.resource().getParentRegion() == &cluster.body(); - }, - /*move_load=*/ - [&](TF::ReadVariableOp read) { - read.getOperation()->moveBefore(cluster); - }); - - // Move stores of external resources, if any, to after cluster. - auto new_cluster = SinkResourceStores(cluster, &builder); - - llvm::SetVector captured_values; - getUsedValuesDefinedAbove(new_cluster.body(), new_cluster.body(), - captured_values); - - for (Value v : captured_values) { - auto tensor_type = v.getType().dyn_cast(); - if (!tensor_type) continue; - if (!tensor_type.getElementType().isa()) continue; - - return new_cluster.emitOpError() - << "has remaining resource inputs that can not be lifted"; + // Move regions to the new op. + for (auto it : llvm::zip(op_->getRegions(), new_op->getRegions())) { + Region& old_region = std::get<0>(it); + Region& new_region = std::get<1>(it); + new_region.takeBody(old_region); } + // Insert stores to all written resources. + for (Value resource : written_resources_) { + ResourceInfo& info = resources_[resource]; + Value value_to_write = new_op->getResult(info.result_index); + Operation* write = builder.create( + op_->getLoc(), resource, value_to_write); + write->setAttrs(info.write_attrs); + } + + // As a part of lifting, we either reuse an existing slot for resource type + // results or add a new slot. Resource type results should not have any uses + // to begin with. So we can safely replace each old op result with the + // corresponding new op result. + int old_num_results = op_->getNumResults(); + op_->replaceAllUsesWith(new_op->getResults().take_front(old_num_results)); + op_->erase(); + op_ = nullptr; +} + +// Lift resource load and stores out of regions attached to `op`, where op is +// an If/case/cluster op. +LogicalResult RegionResourceHoister::HoistResourcesOutOfIfCaseCluster( + Operation* op) { + RegionResourceHoister hoister(op); + if (failed(hoister.Analyze())) return failure(); + + // If there are no resource region captures, then nothing to do. + if (!hoister.NeedsLifting()) return success(); + + // Start the transformation. For each region, replace the resource read with + // the value read before the op. 
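Taken together, the hoister brackets the op with resource accesses: a ReadVariableOp per lifted resource before the op, and an AssignVariableOp per written resource fed by the new op's extra result after it. The condensed sketch below shows that bracketing for a single resource; the helper and its parameters are illustrative, while the builder calls mirror the ones used above.

// Illustration of the read-before / write-after bracketing for one resource.
static void BracketWithReadAndWrite(Operation* old_op, Operation* new_op,
                                    Value resource, Type data_type,
                                    Value lifted_value) {
  OpBuilder builder(old_op);
  Operation* read = builder.create<TF::ReadVariableOp>(old_op->getLoc(),
                                                       data_type, resource);
  (void)read;  // in the pass, this result replaces the in-region reads
  builder.setInsertionPointAfter(new_op);
  builder.create<TF::AssignVariableOp>(new_op->getLoc(), resource,
                                       lifted_value);
}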
+  hoister.GenerateHoistedReads();
+  hoister.ReplaceResourceLoads(/*read_only=*/false);
+  hoister.AppendResourceStoreValueToReturn(op->getRegions());
+  hoister.ReplaceOpWithNewOp();
  return success();
}

+// Lift resource loads and stores out of WhileRegion.
+LogicalResult RegionResourceHoister::HoistResourcesOutOfWhileRegion(
+    TF::WhileRegionOp op) {
+  // For WhileRegion, post canonicalization all resources used within the
+  // body and condition regions are replaced with captured values, so we do not
+  // need to take into account the body and condition region arguments.
+  RegionResourceHoister hoister(op);
+
+  if (failed(hoister.Analyze())) return failure();
+
+  // If there are no resource region captures, then nothing to do.
+  if (!hoister.NeedsLifting()) return success();
+
+  // The resources captured for While loop fall into two categories:
+  //  (a) read-only. These reads can be replaced by a hoisted read created
+  //      before the WhileOp (similar to if and case).
+  //  (b) written: since the value is written in the loop (which can only
+  //      happen in the loop body), all these will become loop variables. Since
+  //      all resource variables are removed from the loop variables during
+  //      canonicalization, we need to create new operand/result slots. The
+  //      input operands for these slots are the read values prior to the op,
+  //      and all references to these are replaced by the corresponding slot
+  //      argument. We need to generate writes following the while for these
+  //      resources.
+  //
+  // Note that for WhileRegion ops, if a resource is written, it will be written
+  // only in the body and not the condition, so the hoister analysis will infer
+  // it as needing a read as well.
+
+  // Generate hoisted reads before the while.
+  hoister.GenerateHoistedReads();
+
+  // Replace just the read-only resources with the hoisted reads.
+  hoister.ReplaceResourceLoads(/*read_only=*/true);
+
+  // For written resources, add additional operands to the while op.
+  int num_old_results = op.getNumResults();
+  int num_new_results = hoister.GetLiftedNumResults();
+  int num_extra_results = num_new_results - num_old_results;
+
+  SmallVector new_result_types;
+  SmallVector new_while_operands;
+  new_result_types.resize(num_extra_results);
+  new_while_operands.resize(num_extra_results);
+
+  for (auto& it : hoister.GetResources()) {
+    if (!it.second.is_written) continue;
+    int index = it.second.result_index - num_old_results;
+    new_result_types[index] = it.second.data_type;
+    new_while_operands[index] = it.second.hoisted_read;
+  }
+  op.getOperation()->insertOperands(op.getNumOperands(), new_while_operands);
+
+  // Patch the cond and body regions to have additional arguments, and replace
+  // the remaining resource reads (which will be resource reads for written
+  // resources) with these arguments.
+  for (Region* region : op.getRegions()) {
+    region->addArguments(new_result_types);
+    // Point hoisted read for written resources to the region's arguments.
+    for (auto& it : hoister.GetResources()) {
+      if (!it.second.is_written) continue;
+      it.second.hoisted_read = region->getArgument(it.second.result_index);
+    }
+    hoister.ReplaceResourceLoads(*region, /*read_only=*/false);
+  }
+
+  // Add additional return values to body return. These correspond to values
+  // written to resources in the body region.
+  hoister.AppendResourceStoreValueToReturn(op.getRegions().drop_front());
+
+  // Finally, create a new while with additional return values.
+ hoister.ReplaceOpWithNewOp(); + return success(); +} + +// Lift resources out of the regions attached to `op` +LogicalResult RegionResourceHoister::ReplaceOpWithNewOp(Operation* op) { + if (auto while_op = dyn_cast(op)) + return HoistResourcesOutOfWhileRegion(while_op); + return HoistResourcesOutOfIfCaseCluster(op); +} + // Holds information about a function's use of a resource argument. struct ResourceArgUseInfo { + // Data type of the data contained in the resource. Type data_type; + // Is the resource argument used in an assign op? bool updated; + // Is the resource argument used in a read or assign op? bool used; }; @@ -356,34 +650,35 @@ struct ResourceArgUseInfo { LogicalResult FindResourceArgUseInfo( FuncOp func_op, llvm::SmallDenseMap* result) { auto return_op = func_op.front().getTerminator(); - for (auto arg : func_op.getArguments()) { - if (!getElementTypeOrSelf(arg.getType()).isa()) continue; + for (auto arg : TF::filter_resources(func_op.getArguments())) { ResourceArgUseInfo info; info.used = false; info.updated = false; - bool do_not_touch = false; + bool read_or_assigned = false; + bool used_in_unsupported_op = false; for (auto user : arg.getUsers()) { if (user == return_op) continue; + info.used = true; if (auto read = llvm::dyn_cast(user)) { - info.used = true; + read_or_assigned = true; info.data_type = read.getType(); continue; } + if (auto assign = llvm::dyn_cast(user)) { - info.used = true; + read_or_assigned = true; info.updated = true; info.data_type = assign.value().getType(); continue; } - if (isa(user)) { - // Stacks will be handled by a separate pass. - do_not_touch = true; - break; - } - user->emitOpError("found unsupported operations on resource."); - return failure(); + + used_in_unsupported_op = true; + break; } - if (!do_not_touch) (*result)[arg.getArgNumber()] = info; + + // If the arg is used in an unsupported op, skip lifting it. + if (used_in_unsupported_op) continue; + (*result)[arg.getArgNumber()] = info; } return success(); } @@ -469,59 +764,61 @@ void RemoveUnusedResourceArgumentsAndForwardedRetvals( // signature. resource_data_types is the (index, data type) pair for each // resource argument. handle_updated_arg_value is a caller-provided function // that handles the updated value for an resource argument. -void LiftArgRetResourcesForFunction( +LogicalResult LiftArgRetResourcesForFunction( FuncOp func_op, const llvm::SmallDenseMap& resource_data_types, llvm::function_ref handle_updated_arg_value) { ForwardStoreToLoad(&func_op.front()); - // Maps a resource argument to the first read. - llvm::SmallDenseMap resource_arg_read; - // Maps a resource argument to the last write. - llvm::SmallDenseMap resource_arg_write; - // Use HoistResourceLoads to CSE loads and the `move_load` function only - // records the remaining load to resource_arg_read. - HoistResourceLoads( - &func_op.front(), - /*skip_load=*/ - [&](TF::ReadVariableOp read) { - return !read.resource().isa(); - }, - /*move_load=*/ - [&](TF::ReadVariableOp read) { - resource_arg_read[read.resource()] = read; - }); - // Record the stores in resource_arg_read. - for (auto& op : llvm::make_early_inc_range(func_op.front())) { - auto write = llvm::dyn_cast(&op); - if (!write) continue; - auto arg = write.resource().dyn_cast(); - if (!arg) continue; - // After ForwardStoreToLoad(), there should be just one store for each - // resource. - resource_arg_write[arg] = write; - } - // Now change the input types to non-resource and remove the internal loads. 
-  auto new_types = llvm::to_vector<8>(func_op.getType().getInputs());
-  for (auto& entry : resource_data_types) {
-    auto arg = func_op.getArgument(entry.getFirst());
-    auto read_it = resource_arg_read.find(arg);
-    auto write_it = resource_arg_write.find(arg);
-    arg.setType(entry.getSecond());
-    new_types[arg.getArgNumber()] = entry.getSecond();
-    if (read_it != resource_arg_read.end()) {
-      read_it->getSecond().replaceAllUsesWith(arg);
-      read_it->getSecond().erase();
-    }
-    if (write_it != resource_arg_write.end()) {
-      handle_updated_arg_value(arg.getArgNumber(),
-                               write_it->getSecond().value());
-      write_it->getSecond().erase();
+
+  RegionResourceHoister hoister(func_op);
+  if (failed(hoister.Analyze())) return failure();
+
+  // Each of these resources could be read or written in the function. If it is
+  // read, we need to replace the resource arg with a value arg to get the
+  // read value. If it is written, we need to replace the write with an
+  // additional value to be written.
+
+  // Now create read values that will be used to replace each resource that
+  // is read in the function body. These read values are just the same argument
+  // with its type replaced.
+  llvm::SmallVector skipped_args;
+  for (auto& it : hoister.GetResources()) {
+    BlockArgument arg = it.first.dyn_cast();
+    assert(arg && "Expect resources for FuncOp to be its arguments");
+    auto type_iter = resource_data_types.find(arg.getArgNumber());
+    if (type_iter == resource_data_types.end()) {
+      // Skip lifting the resource if it's not present in the data type map.
+      // This indicates that the resource is not to be lifted because it is used
+      // in an unsupported op in some other function.
+      skipped_args.push_back(arg);
+    } else {
+      arg.setType(type_iter->second);
+      it.second.hoisted_read = arg;
+    }
  }
-  func_op.setType(FunctionType::get(
-      new_types,
-      llvm::to_vector<4>(func_op.front().getTerminator()->getOperandTypes()),
-      func_op.getContext()));
+
+  // Drop all the args that have to be skipped.
+  for (Value arg : skipped_args) hoister.DropResource(arg);
+
+  hoister.ReplaceResourceLoads(/*read_only=*/false);
+
+  // For writes, invoke the callback and then erase the write.
+  auto assign_ops = func_op.front().getOps();
+  for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) {
+    Value resource = assign_variable_op.resource();
+    if (!hoister.Contains(resource)) continue;
+
+    auto arg = resource.dyn_cast();
+    handle_updated_arg_value(arg.getArgNumber(), assign_variable_op.value());
+    assign_variable_op.erase();
+  }
+
+  func_op.setType(
+      FunctionType::get(func_op.front().getArgumentTypes(),
+                        func_op.front().getTerminator()->getOperandTypes(),
+                        func_op.getContext()));
+
+  return success();
}

// Returns a vector filtered from range where the unused elements (specified by
@@ -570,29 +867,7 @@ void AddLoadsStoresOutsideControlFlowOp(
// Lifts loads/stores from while loop's body and cond functions.
LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
-  // Remove identity nodes to avoid aliasing.
-  RemoveIdentity(&body.front());
-  RemoveIdentity(&cond.front());
  auto return_op = body.front().getTerminator();
-  // Sanity check: body resource input/output should alias each other.
- for (auto arg : body.getArguments()) { - if (!getElementTypeOrSelf(arg.getType()).isa()) continue; - if (return_op->getOperand(arg.getArgNumber()) != arg) { - return return_op->emitOpError( - "resource used in while loop is only supported when the ") - << "resource input and output alias each other in the loop body."; - } - } - // FindResourceArgUseInfo will check supported resource ops (read and assign), - // but loop condition has additional requirement that it cannot write - // resources. - if (cond.walk([&](TF::AssignVariableOp assign) { - assign.emitOpError("found resource write in loop condition."); - return WalkResult::interrupt(); - }) - .wasInterrupted()) { - return failure(); - } llvm::SmallDenseMap body_use_info; llvm::SmallDenseMap cond_use_info; if (failed(FindResourceArgUseInfo(body, &body_use_info)) || @@ -603,12 +878,7 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { auto resource_arg_uses = MergeArgResourceUseInfo(body_use_info, cond_use_info); if (resource_arg_uses.empty()) return success(); - for (const auto& entry : resource_arg_uses) { - // Replace output resource uses with the input, so that we can later freely - // change the output type. - while_op.getResult(entry.getFirst()) - .replaceAllUsesWith(while_op.getOperand(entry.getFirst())); - } + // Remove unused resources in functions. llvm::SmallVector old_to_new_indices; llvm::SmallDenseMap remaining_resource_data_types; @@ -661,50 +931,8 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { // Lifts loads/stores from an IfOp or CaseOp's branches. template LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef branches) { - // Remove identity nodes to avoid aliasing. - for (auto func : branches) RemoveIdentity(&func.front()); - - // Sanity check: branch return of resources should be aliases of inputs. If - // so, replace the output uses with the input so that we can remove these - // outputs. - for (OpResult result : op.getResults()) { - if (!getElementTypeOrSelf(result.getType()).isa()) - continue; - unsigned result_index = result.getResultNumber(); - constexpr unsigned kUnassigned = -1; - unsigned common_aliasing_arg_num = kUnassigned; - for (auto func : branches) { - auto retval = func.front().getTerminator()->getOperand(result_index); - assert(result.getType() == retval.getType()); - auto aliasing_arg = retval.dyn_cast(); - if (!aliasing_arg) - return op.emitOpError("unsupported output: ") - << "resource does not alias input"; - if (common_aliasing_arg_num == kUnassigned) - common_aliasing_arg_num = aliasing_arg.getArgNumber(); - if (aliasing_arg.getArgNumber() != common_aliasing_arg_num) - return op.emitOpError("unsupported output: ") - << "resource does not alias a single input"; - } - assert(common_aliasing_arg_num != kUnassigned); - result.replaceAllUsesWith(op.getOperand(common_aliasing_arg_num + 1)); - } - - // Erase the resource outputs from the branches. 
- int64_t non_resource_results = 0; - llvm::SmallVector old_to_new_output_indices; - bool output_removed = false; - for (auto result : op.getResults()) { - if (!getElementTypeOrSelf(result.getType()) - .template isa()) { - old_to_new_output_indices.push_back(non_resource_results++); - continue; - } - old_to_new_output_indices.push_back(-1); - for (auto func : branches) - func.front().getTerminator()->eraseOperand(non_resource_results); - output_removed = true; - } + // For canonicalized If/Case, there should not be any resource outputs + int64_t non_resource_results = op.getNumResults(); llvm::SmallDenseMap resource_arg_uses; if (failed(FindResourceArgUseInfo(branches.front(), &resource_arg_uses))) @@ -719,7 +947,7 @@ LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef branches) { MergeArgResourceUseInfo(resource_arg_uses, branch_use_info); } - if (resource_arg_uses.empty() && !output_removed) return success(); + if (resource_arg_uses.empty()) return success(); // Remove unused resources in functions. llvm::SmallDenseMap remaining_resource_data_types; RemoveUnusedResourceArgumentsAndForwardedRetvals( @@ -794,12 +1022,7 @@ LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef branches) { AddLoadsStoresOutsideControlFlowOp(new_op, arg_data_type_and_updated_output_index); // Replace uses. - for (int64_t i = 0, end = old_to_new_output_indices.size(); i < end; ++i) { - if (old_to_new_output_indices[i] >= 0) { - op.getResult(i).replaceAllUsesWith( - new_op.getResult(old_to_new_output_indices[i])); - } - } + op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults())); op.erase(); return success(); } @@ -825,8 +1048,6 @@ struct PartitionedCallLiftingInfo { // happens on a clone, which will be stored in `result`. LogicalResult HandlePartitionedCallOpCallee( FuncOp callee, PartitionedCallLiftingInfo* result) { - // Remove identity nodes to avoid aliasing. - RemoveIdentity(&callee.front()); // Sanity check: return of resources should be aliases of inputs. Such outputs // will be removed later. int64_t non_resource_results = 0; @@ -914,8 +1135,8 @@ LogicalResult HandlePartitionedCallOpCallee( // resource-lifted new callee function in lifting_info. template void UpdatePartitionedCallOpWithNewCallee( - CallOpType call_op, const PartitionedCallLiftingInfo& lifting_info) { - if (lifting_info.lifted_callee == nullptr) return; + CallOpType call_op, PartitionedCallLiftingInfo& lifting_info) { + if (!lifting_info.lifted_callee) return; // Replace output resource uses with the aliasing input, so that we can remove // this output. for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) { @@ -929,12 +1150,10 @@ void UpdatePartitionedCallOpWithNewCallee( auto new_operands = FilterRange(call_op.args(), lifting_info.use_info); auto new_call = builder.create( - call_op.getLoc(), - const_cast(lifting_info.lifted_callee).getType().getResults(), + call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(), new_operands, call_op.getAttrs()); new_call.setAttr( - "f", builder.getSymbolRefAttr( - const_cast(lifting_info.lifted_callee).getName())); + "f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName())); AddLoadsStoresOutsideControlFlowOp( new_call, lifting_info.arg_data_type_and_updated_output_index); // Replace uses. 
@@ -948,8 +1167,9 @@ void UpdatePartitionedCallOpWithNewCallee( call_op.erase(); } -LogicalResult HoistForFunctionalControlFlow( - Block*, ModuleOp, llvm::SmallDenseMap*); +LogicalResult HoistForControlFlow( + Block*, ModuleOp, + llvm::SmallDenseMap*); // A templated routine for handling both PartitionedCallOp and // StatefulPartitionedCallOp. If the callee is already lifted, it just updates @@ -958,12 +1178,15 @@ LogicalResult HoistForFunctionalControlFlow( template LogicalResult HandlePartitionedCallOp( CallOpType call_op, FuncOp callee, ModuleOp module, - llvm::SmallDenseMap* lifted_callees) { - auto emplace_res = - lifted_callees->try_emplace(callee, PartitionedCallLiftingInfo()); + llvm::SmallDenseMap* + lifted_callees) { + auto emplace_res = lifted_callees->try_emplace(callee.getName(), + PartitionedCallLiftingInfo()); if (emplace_res.second) { // Unseen callee. Perform resource lifting on it. - HoistForFunctionalControlFlow(&callee.front(), module, lifted_callees); + if (failed(HoistForControlFlow(&callee.front(), module, lifted_callees))) + return failure(); + if (failed(HandlePartitionedCallOpCallee( callee, &emplace_res.first->getSecond()))) { return failure(); @@ -975,30 +1198,28 @@ LogicalResult HandlePartitionedCallOp( // Hoists resource loads/stores from control flow ops in `block` outside the // body/cond/branch/callee functions. -LogicalResult HoistForFunctionalControlFlow( +LogicalResult HoistForControlFlow( Block* block, ModuleOp module, - llvm::SmallDenseMap* + llvm::SmallDenseMap* lifted_partitioned_call_callees) { - // Remove identity nodes to avoid aliasing. - RemoveIdentity(block); for (Operation& op : llvm::make_early_inc_range(*block)) { if (auto while_op = llvm::dyn_cast(&op)) { - auto body = while_op.body_func(); - auto cond = while_op.cond_func(); + auto body = while_op.body_function(); + auto cond = while_op.cond_function(); // Recursively handle the nested control flow. - HoistForFunctionalControlFlow(&body.front(), module, - lifted_partitioned_call_callees); - HoistForFunctionalControlFlow(&cond.front(), module, - lifted_partitioned_call_callees); + HoistForControlFlow(&body.front(), module, + lifted_partitioned_call_callees); + HoistForControlFlow(&cond.front(), module, + lifted_partitioned_call_callees); if (failed(HandleWhileLoop(while_op, body, cond))) return failure(); } else if (auto if_op = llvm::dyn_cast(&op)) { - auto then_branch = if_op.then_func(); - auto else_branch = if_op.else_func(); + auto then_branch = if_op.then_function(); + auto else_branch = if_op.else_function(); // Recursively handle the nested control flow. - HoistForFunctionalControlFlow(&then_branch.front(), module, - lifted_partitioned_call_callees); - HoistForFunctionalControlFlow(&else_branch.front(), module, - lifted_partitioned_call_callees); + HoistForControlFlow(&then_branch.front(), module, + lifted_partitioned_call_callees); + HoistForControlFlow(&else_branch.front(), module, + lifted_partitioned_call_callees); if (failed(HandleCaseOrIfOp(if_op, {then_branch, else_branch}))) return failure(); } else if (auto case_op = llvm::dyn_cast(&op)) { @@ -1008,16 +1229,17 @@ LogicalResult HoistForFunctionalControlFlow( FuncOp func = module.lookupSymbol(branch.cast()); // Recursively handle the nested control flow. 
- HoistForFunctionalControlFlow(&func.front(), module, - lifted_partitioned_call_callees); + HoistForControlFlow(&func.front(), module, + lifted_partitioned_call_callees); branch_functions.push_back(func); } if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure(); } else if (auto call_op = llvm::dyn_cast(&op)) { auto callee = call_op.func(); - if (!callee) + if (!callee) { return call_op.emitOpError( "resource lifting does not support call with nested references."); + } if (failed(HandlePartitionedCallOp(call_op, callee, module, lifted_partitioned_call_callees))) { // Nested control flow handling is done in HandlePartitionedCallOp(). @@ -1029,26 +1251,19 @@ LogicalResult HoistForFunctionalControlFlow( lifted_partitioned_call_callees))) { return failure(); } + } else if (isa(op)) { + for (Region& region : op.getRegions()) + HoistForControlFlow(®ion.front(), module, + lifted_partitioned_call_callees); + LogicalResult result = RegionResourceHoister::ReplaceOpWithNewOp(&op); + if (failed(result)) return failure(); } } - // Remove unused local variables. + // After we have hoisted operations in the block, we may have added new read + // and writes of resources to this block. Clean them up by doing store-load + // forwarding. ForwardStoreToLoad(block); - llvm::SmallVector local_vars; - for (Operation& op : *block) { - if (auto local_var = llvm::dyn_cast(&op)) { - local_vars.push_back(local_var); - } - } - for (auto local_var : local_vars) { - if (llvm::all_of(local_var.resource().getUsers(), - [](const Operation* user) { - return isa(user); - })) { - for (auto user : local_var.resource().getUsers()) user->erase(); - local_var.erase(); - } - } return success(); } @@ -1056,22 +1271,25 @@ LogicalResult HoistForFunctionalControlFlow( // Returns failure if there are remaining resource-type values that can not be // lifted. 
void ResourceOpLiftingPass::runOnOperation() { - llvm::SmallDenseMap + llvm::SmallDenseMap lifted_partitioned_call_callees; ModuleOp module = getOperation(); - auto result = module.walk([&](FuncOp func_op) { + + if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(module))) + return signalPassFailure(); + + auto walk_result = module.walk([&](FuncOp func_op) { return func_op.walk([&](tf_device::ClusterOp cluster) { - if (failed(HoistForFunctionalControlFlow( - &cluster.GetBody(), module, &lifted_partitioned_call_callees)) || - failed(HoistResourceOpsFromCluster(cluster, module))) { - return WalkResult::interrupt(); - } + LogicalResult result = HoistForControlFlow( + &cluster.GetBody(), module, &lifted_partitioned_call_callees); + if (failed(result)) return WalkResult::interrupt(); + result = RegionResourceHoister::ReplaceOpWithNewOp(cluster); + if (failed(result)) return WalkResult::interrupt(); return WalkResult::advance(); }); }); - if (result.wasInterrupted()) { - signalPassFailure(); - } + + if (walk_result.wasInterrupted()) return signalPassFailure(); } struct ResourceOpLiftingForMainFunctionPass @@ -1121,11 +1339,14 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) { << function.getBlocks().size(); } - llvm::SmallDenseMap + if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(function))) + return failure(); + + llvm::SmallDenseMap lifted_partitioned_call_callees; - return HoistForFunctionalControlFlow(&function.front(), - cast(function.getParentOp()), - &lifted_partitioned_call_callees); + return HoistForControlFlow(&function.front(), + cast(function.getParentOp()), + &lifted_partitioned_call_callees); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc new file mode 100644 index 00000000000..97030595c99 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc @@ -0,0 +1,464 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h" + +#include "llvm/ADT/BitVector.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace { + +bool IsResource(Value value) { + return getElementTypeOrSelf(value.getType()).isa(); +} + +// Removes identity nodes in the block. The device computation does not need +// such nodes to carry information. 
+void RemoveIdentity(Block &block) {
+  for (auto &op : llvm::make_early_inc_range(block)) {
+    if (isa(&op)) {
+      op.replaceAllUsesWith(op.getOperands());
+      op.erase();
+    }
+  }
+}
+
+// Eliminate local variables that are only assigned to but never read, and thus
+// are dead.
+void RemoveDeadLocalVariables(Block &block) {
+  llvm::SmallVector local_vars;
+  for (Operation &op : block) {
+    if (auto local_var = llvm::dyn_cast(&op)) {
+      local_vars.push_back(local_var);
+    }
+  }
+  for (auto local_var : local_vars) {
+    if (llvm::all_of(local_var.resource().getUsers(),
+                     [](const Operation *user) {
+                       return isa(user);
+                     })) {
+      for (auto user : local_var.resource().getUsers()) user->erase();
+      local_var.erase();
+    }
+  }
+}
+
+LogicalResult CleanupAndCanonicalize(Operation *parent_op);
+
+// Eliminates unused results from an operation `op` by cloning it with reduced
+// result types and doing appropriate use replacements. `results_to_eliminate`
+// is a bitvector of result positions to eliminate. If it is null, then all
+// unused results of the operation will be eliminated.
+void EliminateUnusedResults(
+    Operation *op, const llvm::BitVector *results_to_eliminate = nullptr) {
+  auto can_eliminate = [&](OpResult &result) -> bool {
+    if (!result.use_empty()) return false;
+    if (results_to_eliminate)
+      return results_to_eliminate->test(result.getResultNumber());
+    else
+      return true;
+  };
+  SmallVector new_result_types;
+  for (OpResult result : op->getResults()) {
+    if (can_eliminate(result)) continue;
+    new_result_types.push_back(result.getType());
+  }
+
+  // Rebuild the operation with fewer results.
+  OpBuilder builder(op);
+  Operation *new_op = Operation::create(
+      op->getLoc(), op->getName(), new_result_types,
+      llvm::to_vector<4>(op->getOperands()), op->getAttrs(),
+      llvm::to_vector<4>(op->getSuccessors()), op->getNumRegions());
+  builder.insert(new_op);
+
+  // Move region bodies to the new operation.
+  for (auto it : llvm::zip(op->getRegions(), new_op->getRegions())) {
+    Region &old_region = std::get<0>(it);
+    Region &new_region = std::get<1>(it);
+    new_region.takeBody(old_region);
+  }
+
+  // Replace used results and erase the old op.
+  int next_result_idx = 0;
+  for (OpResult result : op->getResults()) {
+    if (can_eliminate(result)) continue;
+    result.replaceAllUsesWith(new_op->getResult(next_result_idx++));
+  }
+  op->erase();
+}
+
+// Clones a function if it cannot be patched in place. Clone if there are
+// multiple uses or unknown uses (for external functions). The cloned function
+// will be marked as private.
+FuncOp CloneFunctionIfNeeded(FuncOp func) {
+  ModuleOp module = func.getParentOfType();
+  auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
+  if (func_uses.hasValue() && llvm::hasSingleElement(func_uses.getValue()))
+    return func;
+  FuncOp cloned = func.clone();
+  cloned.setVisibility(SymbolTable::Visibility::Private);
+  cloned.setName(func.getName().str() + "_lifted");
+  SymbolTable(module).insert(cloned);
+  return cloned;
+}
+
+// Eliminates unused results for If/Case operations. Also patches up the
+// branch functions to (a) drop the unused return values, and (b) if, as a
+// result, some argument becomes unused in all branches, drop that argument
+// and the corresponding if/case input operand.
+void EliminateUnusedResultsForIfCase(Operation *op, ArrayRef branches) {
+  // Clone branch functions if needed since we will be mutating them.
+  SmallVector cloned_branches;
+  cloned_branches.reserve(branches.size());
+  for (FuncOp func : branches) {
+    FuncOp cloned = CloneFunctionIfNeeded(func);
+    cloned_branches.push_back(cloned);
+    if (cloned == func) continue;
+    // Patch up the op attribute to point to the new function.
+    for (NamedAttribute attr : op->getAttrs()) {
+      auto symref = attr.second.dyn_cast();
+      if (!symref) continue;
+      if (symref.getValue() != func.getName()) continue;
+      op->setAttr(attr.first,
+                  FlatSymbolRefAttr::get(cloned.getName(), op->getContext()));
+      break;
+    }
+  }
+
+  // Traverse results backward so that indices to be deleted stay unchanged.
+  for (OpResult result : llvm::reverse(op->getResults())) {
+    if (!result.use_empty()) continue;
+    int result_idx = result.getResultNumber();
+    for (FuncOp func : cloned_branches)
+      func.front().getTerminator()->eraseOperand(result_idx);
+  }
+
+  // Check which function arguments are unused in all branches. We can drop
+  // those as well.
+  int num_args = cloned_branches[0].getNumArguments();
+  llvm::BitVector used_args(num_args);
+  for (FuncOp func : branches) {
+    for (BlockArgument arg : func.getArguments()) {
+      if (!arg.use_empty()) used_args.set(arg.getArgNumber());
+    }
+  }
+
+  // There are some unused args that we can drop. Also drop the corresponding
+  // input operand.
+  if (used_args.count() != num_args) {
+    // Traverse arguments backward so that indices to be deleted stay unchanged.
+    for (int idx = num_args - 1; idx >= 0; --idx) {
+      if (used_args.test(idx)) continue;
+      for (FuncOp func : cloned_branches) func.eraseArgument(idx);
+      // For if/case, arg #i of the attached function corresponds to operand
+      // #i+1.
+      op->eraseOperand(idx + 1);
+    }
+  }
+
+  // Patch up function types (with fewer return values and potentially fewer
+  // arguments).
+  for (FuncOp func : cloned_branches) {
+    func.setType(FunctionType::get(
+        func.front().getArgumentTypes(),
+        func.front().getTerminator()->getOperandTypes(), func.getContext()));
+  }
+
+  EliminateUnusedResults(op);
+}
+
+// Eliminates unused results from a functional while.
+void EliminateUnusedResultsForWhile(TF::WhileOp op) {
+  FuncOp cond = op.cond_function();
+  FuncOp body = op.body_function();
+
+  llvm::BitVector can_eliminate(op.getNumResults());
+  for (OpResult result : llvm::reverse(op.getResults())) {
+    if (!result.use_empty()) continue;
+    int result_idx = result.getResultNumber();
+    BlockArgument cond_arg = cond.getArgument(result_idx);
+    BlockArgument body_arg = body.getArgument(result_idx);
+    Operation *body_ret = body.front().getTerminator();
+    // We can eliminate a result if it is unused, the corresponding argument is
+    // unused in cond, and the only use in body is as a return value.
+    if (cond_arg.use_empty() && body_arg.hasOneUse() &&
+        body_arg.use_begin()->getOperandNumber() == result_idx &&
+        body_arg.use_begin()->getOwner() == body_ret) {
+      can_eliminate.set(result_idx);
+    }
+  }
+
+  if (can_eliminate.empty()) return;
+
+  FuncOp cloned_cond = CloneFunctionIfNeeded(cond);
+  FuncOp cloned_body = CloneFunctionIfNeeded(body);
+  op.condAttr(FlatSymbolRefAttr::get(cloned_cond.getName(), op.getContext()));
+  op.bodyAttr(FlatSymbolRefAttr::get(cloned_body.getName(), op.getContext()));
+
+  // Drop cond/body args and return value. WhileOp result will be dropped later
+  // in EliminateUnusedResults. Traverse in reverse order so that indices to be
+  // deleted stay unchanged.
+  for (int idx = op.getNumResults() - 1; idx >= 0; --idx) {
+    if (!can_eliminate.test(idx)) continue;
+    cloned_cond.eraseArgument(idx);
+    cloned_body.front().getTerminator()->eraseOperand(idx);
+    cloned_body.eraseArgument(idx);
+  }
+
+  // Patch up branch function types.
+  for (FuncOp func : {cloned_cond, cloned_body}) {
+    func.setType(FunctionType::get(
+        func.front().getArgumentTypes(),
+        func.front().getTerminator()->getOperandTypes(), func.getContext()));
+  }
+  EliminateUnusedResults(op, &can_eliminate);
+}
+
+// For resource results, replace all uses with the resource input to which the
+// result is tied. After this, resource outputs of this op are expected to be
+// unused.
+LogicalResult ForwardCommonArgToOutput(Operation *op, ArrayRef branches,
+                                       ValueRange branch_args,
+                                       bool &has_resource_result) {
+  // For while, the branch inputs and outputs need to match.
+  bool io_match = isa(op);
+
+  has_resource_result = false;
+  // Check if the same input argument number is passed through all functions.
+  for (OpResult result : op->getResults()) {
+    if (!IsResource(result)) continue;
+
+    has_resource_result = true;
+    int result_idx = result.getResultNumber();
+    Optional common_arg_index;
+    for (FuncOp func : branches) {
+      auto ret = func.front().getTerminator();
+      auto block_arg = ret->getOperand(result_idx).dyn_cast();
+      if (!block_arg) {
+        return op->emitOpError("result #")
+               << result_idx << " not tied to function argument for branch @"
+               << func.getName();
+      }
+      if (!common_arg_index.hasValue()) {
+        common_arg_index = block_arg.getArgNumber();
+      } else if (common_arg_index.getValue() != block_arg.getArgNumber()) {
+        return op->emitError("result #")
+               << result_idx
+               << " is not tied to the same argument across all branches";
+      }
+    }
+
+    if (io_match && result_idx != common_arg_index.getValue()) {
+      return op->emitOpError("Result #")
+             << result_idx << " is tied to argument #"
+             << common_arg_index.getValue();
+    }
+
+    // Forward the corresponding input to the output.
+    result.replaceAllUsesWith(branch_args[common_arg_index.getValue()]);
+  }
+  return success();
+}
+
+// Canonicalizes a functional If/Case. Forwards the input argument to resource
+// results and then deletes the resource results.
+LogicalResult CanonicalizeFunctionalIfCase(Operation *op,
+                                           ArrayRef branches,
+                                           ValueRange branch_args) {
+  for (FuncOp func : branches) {
+    if (failed(CleanupAndCanonicalize(func))) return failure();
+  }
+
+  bool has_resource_result = false;
+  if (failed(ForwardCommonArgToOutput(op, branches, branch_args,
+                                      has_resource_result)))
+    return failure();
+
+  // If no resource type results were found, no further cleanup needed.
+  if (!has_resource_result) return success();
+
+  // Drop unused results.
+  EliminateUnusedResultsForIfCase(op, branches);
+  return success();
+}
+
+// Canonicalizes a functional While. Forwards the common argument to results
+// and drops resource results if possible.
+LogicalResult CanonicalizeFunctionalWhile(TF::WhileOp op) {
+  for (FuncOp func : {op.cond_function(), op.body_function()}) {
+    if (failed(CleanupAndCanonicalize(func))) return failure();
+  }
+
+  // For while, just use the body function to forward operand to result.
+  bool has_resource_result = false;
+  if (failed(ForwardCommonArgToOutput(op, {op.body_function()},
+                                      op.getOperands(), has_resource_result)))
+    return failure();
+  // If no resource type results were found, no further cleanup needed.
+  if (!has_resource_result) return success();
+
+  // Drop unused results.
+  EliminateUnusedResultsForWhile(op);
+  return success();
+}
+
+// Canonicalizes region based if/case and cluster operations. If the same
+// captured resource-typed value is used for all region results, then that
+// value is forwarded to the result and the result is dropped.
+LogicalResult CanonicalizeRegionIfCaseCluster(Operation *op) {
+  // Check if the same value is used for all region results for this output.
+  bool has_resource_result = false;
+  for (OpResult result : op->getResults()) {
+    if (!IsResource(result)) continue;
+    has_resource_result = true;
+    int result_idx = result.getResultNumber();
+
+    Value ret0 =
+        op->getRegion(0).front().getTerminator()->getOperand(result_idx);
+    for (Region &region : op->getRegions().drop_front()) {
+      Value ret = region.front().getTerminator()->getOperand(result_idx);
+      if (ret != ret0) {
+        return op->emitError("Result #")
+               << result_idx
+               << " not tied to the same capture across all regions";
+      }
+    }
+    result.replaceAllUsesWith(ret0);
+  }
+
+  if (!has_resource_result) return success();
+
+  // Eliminate unused region results. Traverse in reverse order so that
+  // indices to be deleted stay unchanged.
+  for (OpResult result : llvm::reverse(op->getResults())) {
+    if (!result.use_empty()) continue;
+    int result_idx = result.getResultNumber();
+    for (Region &region : op->getRegions())
+      region.front().getTerminator()->eraseOperand(result_idx);
+  }
+  EliminateUnusedResults(op);
+  return success();
+}
+
+// Canonicalizes a region based while. If the same value is passed through
+// the body, the result is replaced with the operand and all arguments/results
+// and return values corresponding to that result are dropped.
+LogicalResult CanonicalizeWhileRegion(TF::WhileRegionOp op) {
+  Region &body = op.body();
+  Region &cond = op.cond();
+  llvm::BitVector can_eliminate(op.getNumResults());
+
+  // Traverse in reverse order so that indices to be deleted stay unchanged.
+  for (OpResult result : llvm::reverse(op.getResults())) {
+    if (!IsResource(result)) continue;
+    int result_idx = result.getResultNumber();
+    auto body_arg = body.front()
+                        .getTerminator()
+                        ->getOperand(result_idx)
+                        .dyn_cast();
+    if (!body_arg || body_arg.getArgNumber() != result_idx) {
+      return op.emitOpError("Result #") << result_idx << " is not tied to arg #"
+                                        << result_idx << " of the body";
+    }
+    body.getArgument(result_idx).replaceAllUsesWith(op.getOperand(result_idx));
+    cond.getArgument(result_idx).replaceAllUsesWith(op.getOperand(result_idx));
+    body.front().getTerminator()->eraseOperand(result_idx);
+    body.eraseArgument(result_idx);
+    cond.eraseArgument(result_idx);
+    result.replaceAllUsesWith(op.getOperand(result_idx));
+    op.getOperation()->eraseOperand(result_idx);
+    can_eliminate.set(result_idx);
+  }
+  EliminateUnusedResults(op, &can_eliminate);
+  return success();
+}
+
+// Removes identities and canonicalizes all operations within `parent_op`.
+LogicalResult CleanupAndCanonicalize(Operation *parent_op) {
+  auto walk_result = parent_op->walk([](Operation *op) {
+    // Cleanup code in attached regions.
+    for (Region &region : op->getRegions()) {
+      if (!llvm::hasSingleElement(region)) return WalkResult::interrupt();
+      RemoveIdentity(region.front());
+      RemoveDeadLocalVariables(region.front());
+    }
+
+    LogicalResult result = success();
+
+    // While condition cannot write to resource variables.
+ auto check_while_cond = [&](TF::AssignVariableOp assign) { + op->emitOpError("found resource write in loop condition."); + return WalkResult::interrupt(); + }; + + if (auto if_op = dyn_cast(op)) { + result = CanonicalizeFunctionalIfCase( + op, {if_op.then_function(), if_op.else_function()}, if_op.input()); + } else if (auto case_op = dyn_cast(op)) { + SmallVector branches; + for (Attribute branch : case_op.branches()) { + auto sym = branch.cast(); + branches.push_back( + SymbolTable::lookupNearestSymbolFrom(op, sym)); + } + result = CanonicalizeFunctionalIfCase(case_op, branches, case_op.input()); + } else if (auto while_op = dyn_cast(op)) { + if (while_op.cond_function().walk(check_while_cond).wasInterrupted()) + return WalkResult::interrupt(); + result = CanonicalizeFunctionalWhile(while_op); + } else if (isa( + op)) { + result = CanonicalizeRegionIfCaseCluster(op); + } else if (auto while_region = dyn_cast(op)) { + if (while_region.cond().walk(check_while_cond).wasInterrupted()) + return WalkResult::interrupt(); + // For while region, the body input and output arg should match. + CanonicalizeWhileRegion(while_region); + } else if (auto call = dyn_cast(op)) { + FuncOp func = dyn_cast(call.resolveCallable()); + if (!func) return WalkResult::interrupt(); + result = CleanupAndCanonicalize(func); + } + return failed(result) ? WalkResult::interrupt() : WalkResult::advance(); + }); + + return failure(walk_result.wasInterrupted()); +} + +} // anonymous namespace + +namespace TF { + +LogicalResult CleanupAndCanonicalizeForResourceOpLifting(FuncOp func) { + return CleanupAndCanonicalize(func); +} + +LogicalResult CleanupAndCanonicalizeForResourceOpLifting(ModuleOp module) { + auto walk_result = module.walk([](tf_device::ClusterOp cluster) { + if (failed(CleanupAndCanonicalize(cluster))) return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(walk_result.wasInterrupted()); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h new file mode 100644 index 00000000000..626ef91bcf6 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_RESOURCE_OP_LIFTING_CLEANUP_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_RESOURCE_OP_LIFTING_CLEANUP_H_ + +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project + +// Performs IR cleanup and canonicalization in preparation for Resource Op +// Lifting pass. 
It does several things: +// - Eliminate identity nodes to remove (most) of resource aliasing +// - Canonicalize functional control flow. For functional control flow we +// expect that any resource output of these ops matches the corresponding +// input, and then forward that input to the output. Fails if this is not the +// case. If successful, the following invariants will hold true: +// (a) For if/case, any resource type results will be deleted. +// (b) For while, any resource type results will be unused. +// - Canonicalize region based control flow. Again, any resource outputs are +// expected to be resolved to be one of the captured resource inputs. Fails +// if this is not the case. If successful, the following invariants will hold +// true: +// (a) For if/case, any resource type results will be deleted. +// (b) For while, any resource type results will be unused. +namespace mlir { +namespace TF { +LogicalResult CleanupAndCanonicalizeForResourceOpLifting(ModuleOp module); +LogicalResult CleanupAndCanonicalizeForResourceOpLifting(FuncOp func); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_RESOURCE_OP_LIFTING_CLEANUP_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 597fbe2c0b1..eef879ca257 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -40,6 +41,8 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -50,10 +53,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" -#include "tensorflow/core/framework/op.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/types.pb.h" @@ -115,12 +115,12 @@ Optional> InferShapeForFunctionReturnType(FuncOp func) { // Returns if the shape inference pass supports an op outside the TF dialect. bool IsSupportedNonTFOp(Operation* op) { - return isa(op); + return isa(op); } // Returns whether a cast back would need to be inserted, e.g., whether the @@ -155,57 +155,6 @@ void UpdateTypeAndInsertIncompatibleUseCasts(Dialect* tf_dialect, Type new_type, result.setType(new_type); } -// Extracts a PartialTensorShape from the MLIR type. 
-Optional GetShapeFromMlirType(Type t) { - if (auto ranked_type = t.dyn_cast()) { - // Convert the MLIR shape indices (int64_t) to TensorFlow indices - // (int64). - ArrayRef shape = ranked_type.getShape(); - SmallVector tf_shape(shape.begin(), shape.end()); - return tensorflow::PartialTensorShape({tf_shape.data(), tf_shape.size()}); - } - return None; -} - -// Gets the subtype's shape and data type for `type`. Templated to support both -// ResourceType and VariantType. -template -std::unique_ptr>> -GetSubtypesHelper(Type type) { - auto type_with_subtypes = - type.cast().getElementType().dyn_cast(); - if (!type_with_subtypes || type_with_subtypes.getSubtypes().empty()) { - return nullptr; - } - auto shapes_and_types = absl::make_unique>>(); - for (auto subtype : type_with_subtypes.getSubtypes()) { - auto shape = GetShapeFromMlirType(subtype); - // handle_shapes_and_types requires all shapes to be known. So if any - // subtype is unknown, clear the vector. - if (!shape) { - shapes_and_types = nullptr; - break; - } - tensorflow::DataType dtype; - auto status = - tensorflow::ConvertToDataType(subtype.getElementType(), &dtype); - assert(status.ok() && "Unknown element type"); - shapes_and_types->emplace_back(*shape, dtype); - } - return shapes_and_types; -} - -// Gets the subtype's shape and data type for `type`. -std::unique_ptr>> -GetSubtypes(Type type) { - auto subclasses = GetSubtypesHelper(type); - if (subclasses) return subclasses; - return GetSubtypesHelper(type); -} - // Returns whether type can be further refined. bool CanBeRefined(Type type) { auto shape_type = type.dyn_cast(); @@ -292,8 +241,8 @@ bool InferShapeForCast(CastOp op, Dialect* tf_dialect) { // function result types. bool InferShapeForIf(IfOp op) { bool changed = false; - auto then_results = op.then_func().getType().getResults(); - auto else_results = op.else_func().getType().getResults(); + auto then_results = op.then_function().getType().getResults(); + auto else_results = op.else_function().getType().getResults(); for (auto it : llvm::zip(op.getResults(), then_results, else_results)) { // If then and else types do not match, skip refinement for that result. if (std::get<1>(it) != std::get<2>(it)) continue; @@ -596,7 +545,7 @@ ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context, bool propagate_caller_callee_constants) : graph_version_(graph_version), propagate_caller_callee_constants_(propagate_caller_callee_constants) { - tf_dialect_ = context->getRegisteredDialect(); + tf_dialect_ = context->getLoadedDialect(); } ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result, @@ -697,11 +646,8 @@ bool ShapeInference::RefineShapeForPassThroughOps(Operation* op) { // TODO(jpienaar): The tf.Cast op, which is uniformly inserted at the // moment, cannot handle arbirary types (e.g., it can't handle quantized // types). This restriction can be relaxed if not only tf.Cast is used. 
- auto kind = t.getKind(); - return (kind >= Type::FIRST_STANDARD_TYPE && - kind < Type::LAST_STANDARD_TYPE) || - (kind >= Type::FIRST_TENSORFLOW_TYPE && - kind < Type::LAST_TENSORFLOW_TYPE); + return t.getDialect().getNamespace().empty() || + isa(t.getDialect()); }; bool changed = false; @@ -747,6 +693,11 @@ bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) { return RefineTypeForPassThroughOperands(op, terminator->getOperands(), op->getResults()); } + if (auto cluster_op = dyn_cast(op)) { + auto terminator = cluster_op.GetBody().getTerminator(); + return RefineTypeForPassThroughOperands(op, terminator->getOperands(), + op->getResults()); + } if (op->hasTrait()) { return RefineShapeForPassThroughOps(op); } @@ -796,182 +747,54 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { if (auto if_region = dyn_cast(op)) return InferShapeForIfRegion(if_region); - StringRef op_name = op->getName().getStringRef(); - // Drop the `tf.` prefix to query TF registry. - auto node_name = - op_name.drop_front(TensorFlowDialect::getDialectNamespace().size() + 1); - - // Get information from the registry and check if we have a shape function for - // this op. - const tensorflow::OpRegistrationData* op_reg_data = - tensorflow::OpRegistry::Global()->LookUp(node_name.data()); - if (!op_reg_data) { - LLVM_DEBUG(llvm::dbgs() << "Skipping inference for unregistered op '" - << op->getName() << "'.\n"); - return false; - } - if (op_reg_data->shape_inference_fn == nullptr) { - LLVM_DEBUG(llvm::dbgs() - << "Skipping inference for op without shape function '" - << op->getName() << "'.\n"); - return false; - } - - // Convert the operation to a NodeDef to be able to use the InferenceContext - // and the TensorFlow shape function. - auto node_def_or = tensorflow::ConvertTFDialectOpToNodeDef( - op, node_name, /*ignore_unregistered_attrs=*/true); - if (!node_def_or.ok()) { - LLVM_DEBUG(llvm::dbgs() - << "Error converting op '" << *op << "' to NodeDef: " - << node_def_or.status().error_message() << "\n"); - return false; - } - std::unique_ptr node_def = - std::move(node_def_or).ValueOrDie(); - - // Collect an array with input values for constant operands and input shapes - // for all the operands. - std::vector input_tensors(op->getNumOperands()); - std::vector input_shapes( - op->getNumOperands()); - std::vector tensors(op->getNumOperands()); - std::vector>>> - handle_shapes_and_types(op->getNumOperands()); - for (auto it : llvm::enumerate(op->getOperands())) { - Value operand = it.value(); - size_t index = it.index(); - - // If the operand is constant, then convert it to Tensor. + // Return operand as a constant attribute. + auto operand_as_constant_fn = [&](Value operand) { ValuePort vp(operand); Attribute attr = ComputeOutputComponent(vp); if (!attr && matchPattern(operand, m_Constant(&attr))) RecordValue(vp, attr); - if (attr) { - tensorflow::Tensor* input_tensor = &tensors[index]; - auto status = - tensorflow::ConvertToTensor(attr.cast(), input_tensor); - if (status.ok()) { - input_tensors[index] = input_tensor; - } else { - LLVM_DEBUG(llvm::dbgs() - << "Error converting input " << index << " of op '" << *op - << "' to Tensor: " << status.error_message() << "\n"); - } - } + return attr; + }; - Type operand_type = operand.getType(); - if (auto shape = GetShapeFromMlirType(operand_type)) { - input_shapes[index] = *shape; - } - // Collect the handle shapes and types for a resource/variant. 
- handle_shapes_and_types[index] = GetSubtypes(operand_type); - } + // Return op result as a shape. + auto op_result_as_shape_fn = [&](InferenceContext& context, + OpResult op_result) { + return ComputeOutputAsShape(op_result, &context); + }; - // Perform the shape inference using an InferenceContext with the input - // shapes. This object is abstracting the information that the ShapeInference - // function operates on. - InferenceContext c(graph_version_, *node_def, op_reg_data->op_def, - input_shapes, input_tensors, - /*input_tensors_as_shapes=*/{}, handle_shapes_and_types); - auto status = c.Run(op_reg_data->shape_inference_fn); - if (!status.ok()) { - LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op - << "': " << status.error_message() << "\n"); + // Return result element type at `index`. + auto result_element_type_fn = [&](int index) { + return op->getResult(index).getType().cast().getElementType(); + }; + + llvm::SmallVector inferred_return_shapes; + if (failed(InferReturnTypeComponentsForTFOp( + /*location=*/None, op, graph_version_, operand_as_constant_fn, + op_result_as_shape_fn, result_element_type_fn, + inferred_return_shapes))) return false; - } - - // Determine if, during shape computation, the shape functions attempted to - // query an input operand as shape where the input was not known/constant. - bool requires_inputs = - any_of(llvm::seq(0, c.num_inputs()), [&](int input) { - return c.requested_input_tensor_as_partial_shape(input) && - !input_tensors[input]; - }); - if (requires_inputs) { - LLVM_DEBUG(llvm::dbgs() << "\trequired input\n"); - std::vector input_tensors_as_shapes; - for (int input : llvm::seq(0, c.num_inputs())) { - if (c.requested_input_tensor_as_partial_shape(input) && - !input_tensors[input]) { - LLVM_DEBUG(llvm::dbgs() << "Requesting " << input << " as shape\n"); - auto op_result = op->getOperand(input).dyn_cast(); - if (!op_result) continue; - // Resize on first valid shape computed. - input_tensors_as_shapes.resize(c.num_inputs()); - auto handle = ComputeOutputAsShape(op_result, &c); - LLVM_DEBUG(llvm::dbgs() << "Requested " << input << " as shape " - << (handle.Handle() ? "found" : "not found")); - if (handle.Handle()) input_tensors_as_shapes[input] = handle; - } - } - - // Attempt to compute the unknown operands as shapes. - // Note: in the case where no partial outputs could be computed, this would - // be empty. - if (!input_tensors_as_shapes.empty()) { - c.set_input_tensors_as_shapes(input_tensors_as_shapes); - auto status = c.Run(op_reg_data->shape_inference_fn); - if (!status.ok()) { - LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op - << "': " << status.error_message() << "\n"); - return false; - } - } - } - - assert(c.num_outputs() == op->getNumResults() && - "inference context matches the MLIR number of results."); // Update the shape for each of the operation result if the InferenceContext // has more precise shapes recorded. bool changed = false; - for (int output : llvm::seq(0, c.num_outputs())) { - // Skip already statically shaped results. 
- Value result = op->getResult(output); - if (!CanBeRefined(result.getType())) continue; - auto shaped_type = result.getType().cast(); + for (auto result : llvm::zip(op->getResults(), inferred_return_shapes)) { + Value op_result = std::get<0>(result); + if (!CanBeRefined(op_result.getType())) continue; - ShapeHandle shape_handle = c.output(output); - LLVM_DEBUG(llvm::dbgs() << "Inferred output " << output << " : " - << c.DebugString(shape_handle) << "\n"); - auto get_tensor_type = [&c](const ShapeHandle& sh, - Type element_type) -> TensorType { - if (!c.RankKnown(sh)) return UnrankedTensorType::get(element_type); - // Convert the shape from TensorFlow (int64) to MLIR (int64_t). - SmallVector shape; - for (int dim : llvm::seq(0, c.Rank(sh))) - shape.push_back(c.Value(c.Dim(sh, dim))); - return RankedTensorType::get(shape, element_type); - }; - auto new_element_type = shaped_type.getElementType(); - // Populate the handle shapes for a resource/variant. - if (new_element_type.isa()) { - auto handle_shapes_types = c.output_handle_shapes_and_types(output); - if (handle_shapes_types) { - SmallVector subtypes; - OpBuilder b(op); - for (const auto& shape_n_type : *handle_shapes_types) { - Type element_type; - auto status = - tensorflow::ConvertDataType(shape_n_type.dtype, b, &element_type); - assert(status.ok() && "Unknown element type"); - subtypes.push_back(get_tensor_type(shape_n_type.shape, element_type)); - } - if (new_element_type.isa()) { - new_element_type = TF::ResourceType::get(subtypes, op->getContext()); - } else { - new_element_type = TF::VariantType::get(subtypes, op->getContext()); - } - } - } - auto new_type = get_tensor_type(shape_handle, new_element_type); - if (result.getType() == new_type) continue; + ShapedTypeComponents inferred = std::get<1>(result); + TensorType inferred_type; + if (inferred.hasRank()) + inferred_type = + RankedTensorType::get(inferred.getDims(), inferred.getElementType()); + else + inferred_type = UnrankedTensorType::get(inferred.getElementType()); - UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, new_type, op, result); + if (op_result.getType() == inferred_type) continue; + UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, inferred_type, op, + op_result); changed = true; } + if (changed) LLVM_DEBUG(llvm::dbgs() << "Modified after shape inference: '" << *op << "'\n"); @@ -1101,7 +924,7 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions( if (auto if_op = dyn_cast(op)) { return PropagateShapeToFunctions( module, drop_begin(if_op.getOperandTypes(), 1), - {if_op.then_func(), if_op.else_func()}, max_iteration); + {if_op.then_function(), if_op.else_function()}, max_iteration); } else if (auto case_op = dyn_cast(op)) { SmallVector branches; for (Attribute branch : case_op.branches()) { @@ -1114,7 +937,7 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions( } else if (auto while_op = dyn_cast(op)) { return PropagateShapeToFunctions( module, while_op.getOperandTypes(), - {while_op.cond_func(), while_op.body_func()}, max_iteration); + {while_op.cond_function(), while_op.body_function()}, max_iteration); } else if (auto call_op = dyn_cast(op)) { if (auto func = dyn_cast(call_op.resolveCallable())) { PropagateConstantToCallee(call_op, func, module); @@ -1174,10 +997,11 @@ LogicalResult ShapeInference::TryToFold(Operation* op) { if (!dialect) return failure(); // Only attempt TF dialect fallback if there are no unknown operands. 
if (some_unknown && dialect == tf_dialect_) return failure(); - SmallVector constants; - if (failed(dialect->constantFoldHook(op, constant_operands, constants))) + auto* interface = dialect->getRegisteredInterface(); + if (!interface) return failure(); + + if (failed(interface->fold(op, constant_operands, fold_results))) return failure(); - fold_results.assign(constants.begin(), constants.end()); } for (auto result : zip(op->getResults(), fold_results)) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index d3755a4a7d0..a0651c6d013 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -163,7 +163,7 @@ LogicalResult HandleWhileOp( const llvm::SmallDenseMap& data_var_to_size_var, llvm::StringMap* decomposed_partitioned_call_callees) { - auto body = while_op.body_func(); + auto body = while_op.body_function(); llvm::SmallDenseMap body_map; auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional { auto it = data_var_to_size_var.find(while_op.getOperand(index)); @@ -187,7 +187,7 @@ LogicalResult HandleWhileOp( return failure(); } // Cond should not change stacks in the arguments, so use an empty map. - auto cond = while_op.cond_func(); + auto cond = while_op.cond_function(); ModifyFunctionSignature(cond, nullptr, find_arg_stack_type); llvm::SmallDenseMap empty_map; if (failed(DecomposeStackOpsInternal(&cond.front(), module, &empty_map, @@ -231,8 +231,8 @@ LogicalResult HandleIfOp( const llvm::SmallDenseMap& data_var_to_size_var, llvm::StringMap* decomposed_partitioned_call_callees) { - auto then_func = if_op.then_func(); - auto else_func = if_op.else_func(); + auto then_func = if_op.then_function(); + auto else_func = if_op.else_function(); llvm::SmallDenseMap then_map; llvm::SmallDenseMap else_map; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index b3a05c06a67..01de6a89c83 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -443,12 +443,12 @@ llvm::SmallDenseMap> AccessedGradients( insert(grad.handle(), grad.source().str()); } else if (auto while_op = llvm::dyn_cast(&op)) { for (const auto& entry : AccessedGradients( - {while_op.body_func(), while_op.cond_func()}, module)) + {while_op.body_function(), while_op.cond_function()}, module)) for (const string& source : entry.getSecond()) insert(while_op.getOperand(entry.getFirst()), source); } else if (auto if_op = llvm::dyn_cast(&op)) { - for (const auto& entry : - AccessedGradients({if_op.then_func(), if_op.else_func()}, module)) + for (const auto& entry : AccessedGradients( + {if_op.then_function(), if_op.else_function()}, module)) for (const string& source : entry.getSecond()) insert(if_op.getOperand(entry.getFirst() + 1), source); } else if (auto call = llvm::dyn_cast(&op)) { @@ -509,8 +509,8 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module, llvm::SmallDenseMap* stats, llvm::StringMap* decomposed_partitioned_call_callees) { - auto body = while_op.body_func(); - auto cond = while_op.cond_func(); + auto body = while_op.body_function(); + auto cond = while_op.cond_function(); auto grads = AccessedGradients({body, cond}, module); 
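The TryToFold change above swaps the dialect-wide constantFoldHook for a registered DialectFoldInterface lookup, so folding bails out early when a dialect never registered the interface. A rough sketch of that lookup-then-fold pattern, with a hypothetical FoldInterfaceSketch standing in for the MLIR interface:

#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for a per-dialect folding interface.
struct FoldInterfaceSketch {
  virtual ~FoldInterfaceSketch() = default;
  virtual bool Fold(const std::vector<int>& constant_operands,
                    std::vector<int>* fold_results) const = 0;
};

// Look the interface up first, return failure if the dialect never
// registered one, and only then attempt the fold.
bool TryFoldThroughInterface(
    const std::map<std::string, std::unique_ptr<FoldInterfaceSketch>>& registry,
    const std::string& dialect, const std::vector<int>& constant_operands,
    std::vector<int>* fold_results) {
  auto it = registry.find(dialect);
  if (it == registry.end() || !it->second) return false;  // no interface
  return it->second->Fold(constant_operands, fold_results);
}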
auto ta_arg_buffer_type = [&](int64_t index) -> Type { auto it = stats->find(while_op.getOperand(index)); @@ -592,8 +592,8 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module, llvm::SmallDenseMap* stats, llvm::StringMap* decomposed_partitioned_call_callees) { - auto then_branch = if_op.then_func(); - auto else_branch = if_op.else_func(); + auto then_branch = if_op.then_function(); + auto else_branch = if_op.else_function(); auto grads = AccessedGradients({then_branch, else_branch}, module); auto ta_arg_buffer_type = [&](int64_t index) -> Type { auto it = stats->find(if_op.getOperand(index + 1)); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc new file mode 100644 index 00000000000..f14efeb91ce --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc @@ -0,0 +1,81 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" + +namespace mlir { +namespace TF { +namespace { + +// Deletes the op and forwards the arguments. +template +class PassThroughConversion : public mlir::OpConversionPattern { + public: + explicit PassThroughConversion(MLIRContext *context) + : mlir::OpConversionPattern(context) {} + + LogicalResult matchAndRewrite( + TF_Op op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { // NOLINT + // Just forward the arguments to results. + rewriter.replaceOp(op, operands); + return success(); + } +}; + +class TensorDeviceCopyConversionPass + : public PassWrapper { + public: + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + mlir::ConversionTarget target(getContext()); + + // TODO(tfrt-devs): when device placer is introduced in the lowering pass, + // we need to check if Identity op and it's previous op are placed on the + // same device. If not, we don't fold Identity op since it's used for tensor + // copying between devices. 
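The new tensor_device_copy_conversion.cc pass above is essentially a pass-through rewrite: the pattern replaces an op with its own operands, and the TODO notes that this is only safe when producer and consumer sit on the same device, since a cross-device Identity stands for a real tensor copy. A small plain-C++ sketch of that idea over a hypothetical node graph (NodeSketch and the device check are assumptions, not the pass's actual API):

#include <string>
#include <vector>

struct NodeSketch {
  std::string op_name;            // e.g. "tf.Identity"
  std::string device;             // assigned device, possibly empty
  NodeSketch* input = nullptr;    // single data operand, for simplicity
  std::vector<NodeSketch*> users;
};

// Hypothetical guard reflecting the TODO: only forward the operand when no
// device change is involved.
bool CanForwardPassThrough(const NodeSketch& node) {
  if (node.op_name != "tf.Identity" || node.input == nullptr) return false;
  return node.device.empty() || node.device == node.input->device;
}

// Rewires every user of the Identity to its operand, the plain-graph
// equivalent of rewriter.replaceOp(op, operands).
void ForwardPassThrough(NodeSketch* node) {
  if (!CanForwardPassThrough(*node)) return;
  for (NodeSketch* user : node->users)
    if (user->input == node) user->input = node->input;
  node->users.clear();
}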
+ patterns.insert, + PassThroughConversion>(&getContext()); + + if (failed(applyPartialConversion(getFunction(), target, patterns))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +CreateTensorDeviceCopyConversionPass() { + return std::make_unique(); +} + +static mlir::PassRegistration + tensor_device_copy_pass( + "tf-tensor-device-copy", + "Handle ops that copy tensors between devices. E.g., tf.Identity."); + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index 9634e4a8be3..da6757e6c94 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -155,7 +155,7 @@ LogicalResult HandleWhileOp( llvm::StringMap* decomposed_partitioned_call_callees) { // Rewrite body. - auto body = while_op.body_func(); + auto body = while_op.body_function(); llvm::SmallDenseMap body_map; auto find_arg_tensor_list_type = [&](int64_t index) -> llvm::Optional { auto it = buffer_to_size->find(while_op.getOperand(index)); @@ -176,7 +176,7 @@ LogicalResult HandleWhileOp( auto output_buffer_to_size = AddTensorListSizesToReturn(body, body_map); // Rewrite cond. - auto cond = while_op.cond_func(); + auto cond = while_op.cond_function(); llvm::SmallDenseMap cond_map; ModifyFunctionSignature(cond, cutil::GetSizeType(builder), &cond_map, find_arg_tensor_list_type, arg_buffer_size_is_fixed); @@ -701,9 +701,9 @@ LogicalResult DecomposeTensorListOpsInternal( return failure(); } } else if (auto if_op = llvm::dyn_cast(&op)) { - if (failed(HandleCaseOrIfOp(if_op, {if_op.then_func(), if_op.else_func()}, - module, buffer_to_size, - decomposed_partitioned_call_callees))) { + if (failed(HandleCaseOrIfOp( + if_op, {if_op.then_function(), if_op.else_function()}, module, + buffer_to_size, decomposed_partitioned_call_callees))) { return failure(); } } else if (auto case_op = llvm::dyn_cast(&op)) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/test_resource_alias_analysis.cc b/tensorflow/compiler/mlir/tensorflow/transforms/test_resource_alias_analysis.cc new file mode 100644 index 00000000000..920b2024c0f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/test_resource_alias_analysis.cc @@ -0,0 +1,111 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TF { +namespace { + +// A pass that annotates each operation with a resource type result with the +// aliasing values for each such result. Each value is assigned a unique ID, and +// that ID is used to annotate the operations. +struct TestResourceAliasAnalysis + : public TF::PerFunctionAggregateAnalysisConsumerPass< + TestResourceAliasAnalysis, TF::ResourceAliasAnalysis> { + void runOnFunction(FuncOp func, + const TF::ResourceAliasAnalysis::Info& analysis) { + int64_t next_id = 0; + llvm::SmallDenseMap ids; + + auto assign_id = [&](Value value) { + if (ids.find(value) == ids.end()) ids.insert({value, next_id++}); + }; + + auto get_id = [&](Value value) -> int64_t { + auto it = ids.find(value); + assert(it != ids.end()); + return it->second; + }; + + auto print_aliases = [&](InFlightDiagnostic& diag, Value value) { + diag << ", ID " << get_id(value) << " : "; + if (analysis.IsUnknownResource(value)) { + diag << "Unknown"; + } else { + auto aliases = llvm::to_vector<4>(analysis.GetResourceAliases(value)); + llvm::sort(aliases, + [&](Value v1, Value v2) { return get_id(v1) < get_id(v2); }); + llvm::interleaveComma(aliases, diag, + [&](Value v) { diag << get_id(v); }); + } + }; + + // Assign a unique ID to each value seen in this function. + func.walk([&](Operation* op) { + // For all attached regions, assign ID to the region arguments. + for (Region& region : op->getRegions()) { + for (auto region_arg : filter_resources(region.getArguments())) + assign_id(region_arg); + } + + // Assign ID for all results. + for (auto result : filter_resources(op->getResults())) assign_id(result); + }); + + // Now walk each operation, and annotate it wil remarks for aliases for + // each resource type result + func.walk([&](Operation* op) { + // For all attached regions, assign ID to the region arguments. + for (Region& region : op->getRegions()) { + for (auto region_arg : filter_resources(region.getArguments())) { + InFlightDiagnostic diag = op->emitRemark("Region #") + << region.getRegionNumber() << ", Arg #" + << region_arg.getArgNumber(); + print_aliases(diag, region_arg); + } + } + + for (auto result : filter_resources(op->getResults())) { + InFlightDiagnostic diag = op->emitRemark("Result #") + << result.getResultNumber(); + print_aliases(diag, result); + } + }); + } +}; + +static mlir::PassRegistration pass( + "tf-test-resource-alias-analysis", + "Add remarks based on resource alias analysis result, for testing " + "purpose."); + +} // anonymous namespace +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/test_visitor_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/test_visitor_util.cc new file mode 100644 index 00000000000..689becb796b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/test_visitor_util.cc @@ -0,0 +1,114 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
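The test pass above assigns each resource value a dense ID in first-seen order and prints each result's aliases sorted by ID so the remarks are deterministic. A compact sketch of that bookkeeping, with ValueKey as a hypothetical stand-in for an MLIR Value:

#include <algorithm>
#include <cstdint>
#include <map>
#include <vector>

using ValueKey = const void*;

class IdAssigner {
 public:
  // First insertion wins, so IDs follow first-seen order like assign_id.
  void Assign(ValueKey v) {
    ids_.emplace(v, static_cast<int64_t>(ids_.size()));
  }
  int64_t Get(ValueKey v) const { return ids_.at(v); }

  // Sorted alias IDs, matching the sort-by-ID step before printing.
  std::vector<int64_t> SortedIds(const std::vector<ValueKey>& aliases) const {
    std::vector<int64_t> out;
    out.reserve(aliases.size());
    for (ValueKey v : aliases) out.push_back(Get(v));
    std::sort(out.begin(), out.end());
    return out;
  }

 private:
  std::map<ValueKey, int64_t> ids_;
};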
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h" + +namespace tensorflow { +namespace { + +std::string get_stage_description(const WalkStage &stage) { + if (stage.IsBeforeAllRegions()) return "before all regions"; + if (stage.IsAfterAllRegions()) return "after all regions"; + return "before region #" + std::to_string(stage.GetNextRegion()); +} + +// A pass that annotates each operation with an remarks that include a unique +// step ID and a description of the visitor step. +class TestVisitorUtil + : public mlir::PassWrapper { + public: + void runOnFunction() override { + mlir::FuncOp func = getOperation(); + int step_id = 0; + GenericWalk(func, [&](mlir::Operation *op, const WalkStage &stage) { + op->emitRemark() << step_id++ << ": " << get_stage_description(stage); + }); + + // Exercise static inference of operation type + GenericWalk(func, [&](mlir::TF::IfRegionOp op, const WalkStage &stage) { + op.emitRemark() << step_id++ << ": " << get_stage_description(stage); + }); + } +}; + +class TestVisitorUtilInterrupt + : public mlir::PassWrapper { + public: + void runOnFunction() override { + mlir::FuncOp func = getOperation(); + int step_id = 0; + + auto walker = [&](mlir::Operation *op, const WalkStage &stage) { + if (auto interrupt_before_all = + op->getAttrOfType("interrupt_before_all")) + if (interrupt_before_all.getValue() && stage.IsBeforeAllRegions()) + return mlir::WalkResult::interrupt(); + + if (auto interrupt_after_all = + op->getAttrOfType("interrupt_after_all")) + if (interrupt_after_all.getValue() && stage.IsAfterAllRegions()) + return mlir::WalkResult::interrupt(); + + if (auto interrupt_after_region = + op->getAttrOfType("interrupt_after_region")) + if (stage.IsAfterRegion( + static_cast(interrupt_after_region.getInt()))) + return mlir::WalkResult::interrupt(); + + op->emitRemark() << step_id++ << ": " << get_stage_description(stage); + return mlir::WalkResult::advance(); + }; + + // Interrupt the walk based on attributes on the operation. + auto result = GenericWalk(func, walker); + + if (result.wasInterrupted()) + func.emitRemark() << step_id++ << ": walk was interrupted"; + + // Exercise static inference of operation type for interrupting callback. 
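The visitor-util test passes above exercise a staged walk: the callback fires before any region of an op, again between regions, and again after all regions, and may interrupt the traversal. A simplified sketch of that stage/interrupt protocol over a toy op tree (OpSketch, StageSketch, and StagedWalk are assumptions sketching the idea, not the utility's real API):

#include <functional>
#include <string>
#include <vector>

struct OpSketch {
  std::string name;
  std::vector<std::vector<OpSketch>> regions;  // each region is a list of ops
};

enum class Walk { kAdvance, kInterrupt };

struct StageSketch {
  int next_region = 0;  // region about to be visited
  int num_regions = 0;
  bool BeforeAll() const { return next_region == 0; }
  bool AfterAll() const { return next_region == num_regions; }
};

Walk StagedWalk(OpSketch& op,
                const std::function<Walk(OpSketch&, const StageSketch&)>& fn) {
  StageSketch stage{0, static_cast<int>(op.regions.size())};
  // Callback before any region; the callee may interrupt the whole walk.
  if (fn(op, stage) == Walk::kInterrupt) return Walk::kInterrupt;
  for (auto& region : op.regions) {
    for (OpSketch& nested : region)
      if (StagedWalk(nested, fn) == Walk::kInterrupt) return Walk::kInterrupt;
    ++stage.next_region;  // callback between regions, and after the last one
    if (fn(op, stage) == Walk::kInterrupt) return Walk::kInterrupt;
  }
  return Walk::kAdvance;
}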
+ result = + GenericWalk(func, [&](mlir::TF::IfRegionOp op, const WalkStage &stage) { + return walker(op, stage); + }); + + if (result.wasInterrupted()) + func.emitRemark() << step_id++ << ": walk was interrupted"; + } +}; + +mlir::PassRegistration pass( + "tf-test-visitor-util", + "Add remarks that trace order of visiting operations using TF visitor " + "utilities."); + +mlir::PassRegistration pass_interrupt( + "tf-test-visitor-util-interrupt", + "Add remarks that trace order of visiting operations using TF visitor " + "utilities, interrupt version."); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc index 2a770b2615d..f26887eb276 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc @@ -34,7 +34,7 @@ class SimpleTFDeviceAssignmentPass void runOnFunction() override { Builder builder(&getContext()); - Dialect* tf = getContext().getRegisteredDialect(); + Dialect* tf = getContext().getLoadedDialect(); getFunction().walk([&](Operation* op) { if (auto device_attr = op->getAttrOfType("device")) { // We assign default device to ops with device attribute that is empty. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 1e4caaf5dd6..52ac87ecf71 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/Identifier.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" @@ -43,6 +44,10 @@ namespace tensorflow { class GraphOptPass : public mlir::PassWrapper> { + void getDependentDialects(mlir::DialectRegistry& registry) const override { + mlir::RegisterAllTensorFlowDialects(registry); + } + public: explicit GraphOptPass(std::vector passes) : passes_(std::move(passes)) {} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc new file mode 100644 index 00000000000..93098acdc9d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +// This pass eliminate `_tpu_replicate` and `device` attribute on operations +// that are contained in a tf_device.cluster op. + +namespace mlir { +namespace TFTPU { + +namespace { + +constexpr char kTPUReplicateAttr[] = "_tpu_replicate"; +constexpr char kDeviceAttr[] = "device"; + +class TPUCleanupClusterAttributesPass + : public PassWrapper> { + public: + void runOnOperation() override { + getOperation().walk([](tf_device::ClusterOp cluster) { + cluster.walk([](Operation *op) { + if (isa(op)) return; + for (StringRef attr : {kTPUReplicateAttr, kDeviceAttr}) + op->removeAttr(attr); + }); + }); + } +}; + +PassRegistration pass( + "tf-tpu-cleanup-cluster-attributes", + "Eliminate _tpu_replicate and other attributes from ops in a cluster"); + +} // namespace + +std::unique_ptr> +CreateTPUClusterCleanupAttributesPass() { + return std::make_unique(); +} + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index 162ecd77d4f..c3f40154c79 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -48,6 +48,7 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -70,55 +71,62 @@ constexpr char kBadTPUReplicateAttrMsg[] = using MetadataMap = llvm::SmallDenseMap; +// A set of operations in a cluster. +using ClusterOps = llvm::SmallSetVector; + // Mapping for `_tpu_replicate` attribute to ops of a cluster. -using ClusterMap = llvm::SmallDenseMap, 8>; +using ClusterMap = llvm::SmallDenseMap; struct TPUClusterFormation - : public PassWrapper { - void runOnFunction() override; + : public TF::PerFunctionAggregateAnalysisConsumerPass< + TPUClusterFormation, TF::ResourceAliasAnalysis> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + void runOnFunction( + FuncOp func, + const TF::ResourceAliasAnalysis::Info& resource_alias_analysis); }; // Creates a mapping from the TPUReplicateMetadata ops `_tpu_replicate` // attribute to its attributes and removes the ops. If multiple // TPUReplicateMetadata ops have the same `_tpu_replicate` attribute, an error // will be returned. 
-LogicalResult CollectMetadata(Operation* op, MetadataMap* metadata_map) { - auto result = - op->walk([&](TF::TPUReplicateMetadataOp metadata_op) -> WalkResult { - MutableDictionaryAttr attrs = metadata_op.getAttrs(); +LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) { + // Just look at top-level operations in the block (not nested ones) + for (Operation& op : llvm::make_early_inc_range(*block)) { + auto metadata_op = dyn_cast(op); + if (!metadata_op) continue; - // Missing or bad `_tpu_replicate` attribute. - auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr); - if (!tpu_replicate_attr) - return metadata_op.emitError() << kBadTPUReplicateAttrMsg; + MutableDictionaryAttr attrs = metadata_op.getAttrs(); - auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast(); - if (!tpu_replicate_attr_str || - tpu_replicate_attr_str.getValue().empty()) - return metadata_op.emitError() << kBadTPUReplicateAttrMsg; + // Missing or bad `_tpu_replicate` attribute. + auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr); + if (!tpu_replicate_attr) + return metadata_op.emitError() << kBadTPUReplicateAttrMsg; - // Remove `name` attribute. - attrs.remove(Identifier::get(kNameAttr, metadata_op.getContext())); + auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast(); + if (!tpu_replicate_attr_str || tpu_replicate_attr_str.getValue().empty()) + return metadata_op.emitError() << kBadTPUReplicateAttrMsg; - auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(), - std::move(attrs)); + // Remove `name` attribute. + attrs.remove(Identifier::get(kNameAttr, metadata_op.getContext())); - // There are multiple TPUReplicateMetadata ops with the same - // `_tpu_replicate` attribute. - if (!it.second) { - return metadata_op.emitError() - << "multiple TPUReplicateMetadata ops with the same '" - << kTPUReplicateAttr << "' attribute '" - << tpu_replicate_attr_str.getValue() << "' found"; - } + auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(), + std::move(attrs)); - metadata_op.erase(); - return WalkResult::advance(); - }); - - // Return failure if the walk was interrupted. - return failure(result.wasInterrupted()); + // There are multiple TPUReplicateMetadata ops with the same + // `_tpu_replicate` attribute. + if (!it.second) { + return metadata_op.emitError() + << "multiple TPUReplicateMetadata ops with the same '" + << kTPUReplicateAttr << "' attribute '" + << tpu_replicate_attr_str.getValue() << "' found"; + } + metadata_op.erase(); + } + return success(); } // Collects and clusters ops with the same `_tpu_replicate` attribute. This will @@ -138,14 +146,34 @@ LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) { return success(); } +// Collects all resource ids from an op. 
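The rewritten CollectMetadata above iterates only the top-level operations of the block and erases each metadata op as it goes; llvm::make_early_inc_range makes that safe by advancing the iterator before the body runs. The plain-C++ equivalent over a std::list, with hypothetical TopLevelOp and attribute names:

#include <list>
#include <map>
#include <string>

struct TopLevelOp {
  std::string kind;                          // e.g. "TPUReplicateMetadata"
  std::map<std::string, std::string> attrs;  // flattened attribute view
};

bool CollectMetadataSketch(std::list<TopLevelOp>& block,
                           std::map<std::string, TopLevelOp>* metadata_map) {
  for (auto it = block.begin(); it != block.end();) {
    auto current = it++;  // advance first, so erasing *current* stays safe
    if (current->kind != "TPUReplicateMetadata") continue;
    auto replicate = current->attrs.find("_tpu_replicate");
    if (replicate == current->attrs.end() || replicate->second.empty())
      return false;  // missing or bad `_tpu_replicate` attribute
    if (!metadata_map->emplace(replicate->second, *current).second)
      return false;  // duplicate metadata op for the same cluster
    block.erase(current);
  }
  return true;
}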
+void CollectResourceIdsFromOp( + Operation& op, + const TF::ResourceAliasAnalysis::Info& resource_alias_analysis, + llvm::SmallDenseSet& observed_resource_ids) { + op.walk([&](Operation* inner_op) { + for (Value operand : TF::filter_resources(inner_op->getOperands())) { + if (resource_alias_analysis.IsUnknownResource(operand)) continue; + const auto& ids = resource_alias_analysis.GetResourceUniqueIds(operand); + observed_resource_ids.insert(ids.begin(), ids.end()); + } + for (Value result : TF::filter_resources(inner_op->getResults())) { + if (resource_alias_analysis.IsUnknownResource(result)) continue; + const auto& ids = resource_alias_analysis.GetResourceUniqueIds(result); + observed_resource_ids.insert(ids.begin(), ids.end()); + } + }); +} + // Checks if an op should be moved after a cluster. There may be users of a // cluster interleaved among the cluster ops. bool ShouldMoveOpAfterCluster( - Block* block, Operation* op, - const llvm::SmallSetVector& cluster_ops, - const llvm::SmallSetVector& preceding_users) { - auto result = op->walk([&](Operation* op) { - for (Value operand : op->getOperands()) { + Block* block, Operation* op, const ClusterOps& cluster_ops, + const llvm::SmallSetVector& preceding_users, + const TF::ResourceAliasAnalysis::Info& resource_alias_analysis, + const llvm::SmallDenseSet& observed_resource_ids) { + auto result = op->walk([&](Operation* inner_op) { + for (Value operand : inner_op->getOperands()) { Operation* def = operand.getDefiningOp(); // Operands may not have a defining op (BlockArgument) or is from a // different block. @@ -157,6 +185,14 @@ bool ShouldMoveOpAfterCluster( return WalkResult::interrupt(); } } + + // Check for uses of any resource in or after cluster. + for (Value operand : TF::filter_resources(inner_op->getOperands())) { + if (resource_alias_analysis.IsUnknownResource(operand)) continue; + auto ids = resource_alias_analysis.GetResourceUniqueIds(operand); + for (const auto& id : ids) + if (observed_resource_ids.contains(id)) return WalkResult::interrupt(); + } return WalkResult::advance(); }); @@ -165,16 +201,31 @@ bool ShouldMoveOpAfterCluster( // Collects ops that are before ops in the cluster but are users of other ops // in the cluster. This may happen because users of individual ops in the -// cluster may be interleaved with other ops in the cluster. +// cluster may be interleaved with other ops in the cluster. Resource id's are +// also captured, to keep track of resource usage before, in, or after the +// cluster. +// TODO(lyandy): Extend this to handle all side effecting ops while handling +// transitive data dependencies. 
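The resource bookkeeping above boils down to two steps: ops already assigned to the cluster contribute their resource IDs to an observed set, and an interleaved op must be moved after the cluster as soon as it touches any observed ID. A minimal sketch of that check, with integer IDs standing in for ResourceAliasAnalysis unique IDs:

#include <cstdint>
#include <set>
#include <vector>

using ResourceIds = std::vector<int64_t>;

// Adds the IDs an op touches to the running "observed" set.
void ObserveResources(const ResourceIds& ids,
                      std::set<int64_t>* observed_resource_ids) {
  observed_resource_ids->insert(ids.begin(), ids.end());
}

// An op conflicts with the cluster, and so must move after it, if any of its
// resource IDs was already observed inside the cluster.
bool MustMoveAfterCluster(const ResourceIds& op_resource_ids,
                          const std::set<int64_t>& observed_resource_ids) {
  for (int64_t id : op_resource_ids)
    if (observed_resource_ids.count(id)) return true;
  return false;
}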
llvm::SmallSetVector CollectClusterPrecedingUsers( - Block* block, const llvm::SmallSetVector& cluster_ops) { + Block* block, const ClusterOps& cluster_ops, + const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { llvm::SmallSetVector preceding_users; + llvm::SmallDenseSet observed_resource_ids; - for (Operation& op : llvm::make_range(Block::iterator(cluster_ops.front()), - Block::iterator(cluster_ops.back()))) - if (cluster_ops.count(&op) == 0 && - ShouldMoveOpAfterCluster(block, &op, cluster_ops, preceding_users)) + auto front = Block::iterator(cluster_ops.front()); + auto back = Block::iterator(cluster_ops.back()); + for (Operation& op : llvm::make_range(front, back)) { + if (cluster_ops.contains(&op)) { + CollectResourceIdsFromOp(op, resource_alias_analysis, + observed_resource_ids); + } else if (ShouldMoveOpAfterCluster( + block, &op, cluster_ops, preceding_users, + resource_alias_analysis, observed_resource_ids)) { preceding_users.insert(&op); + CollectResourceIdsFromOp(op, resource_alias_analysis, + observed_resource_ids); + } + } return preceding_users; } @@ -185,7 +236,7 @@ llvm::SmallSetVector CollectClusterPrecedingUsers( // outside of the cluster (i.e. results of ops in the cluster are only consumed // by other ops in the cluster) are pruned. llvm::SmallVector CollectClusterResults( - Block* block, const llvm::SmallSetVector& cluster_ops) { + Block* block, const ClusterOps& cluster_ops) { llvm::SmallVector results; for (Operation* op : cluster_ops) { @@ -204,61 +255,52 @@ llvm::SmallVector CollectClusterResults( } // Creates a `tf_device.cluster` to wrap cluster ops. -tf_device::ClusterOp CreateOpForCluster(Operation* last_cluster_op, - llvm::ArrayRef results) { +tf_device::ClusterOp CreateClusterOp( + Block* block, const ClusterOps& cluster_ops, llvm::ArrayRef results, + llvm::ArrayRef preceding_users) { // `tf_device.cluster` will be placed at where the last op of the cluster is. + Operation* last_cluster_op = cluster_ops.back(); OpBuilder builder(last_cluster_op); llvm::SmallVector result_types; for (Value result : results) result_types.push_back(result.getType()); - auto cluster = builder.create(last_cluster_op->getLoc(), result_types); - cluster.body().push_back(new Block); + Block* body = new Block; + cluster.body().push_back(body); + + // Move cluster ops to the cluster body. Also remove `_tpu_replicate` and + // `device` attribute from ops in the cluster as that information will be + // present in the `tf_device.cluster`. Do this for all ops including nested + // ops. + for (Operation* cluster_op : cluster_ops) { + cluster_op->moveBefore(body, body->end()); + cluster_op->walk([&](Operation* inner_op) { + inner_op->removeAttr(kTPUReplicateAttr); + inner_op->removeAttr(kDeviceAttr); + }); + } // Add terminator. - builder.setInsertionPointToEnd(&cluster.GetBody()); + builder.setInsertionPointToEnd(body); builder.create(last_cluster_op->getLoc(), results); - return cluster; -} - -// Moves cluster ops to associated `tf_device.cluster` body. -void MoveClusterOpsToCluster( - tf_device::ClusterOp cluster, - const llvm::SmallSetVector& cluster_ops) { - MLIRContext* context = cluster.getContext(); - Operation* terminator = cluster.GetBody().getTerminator(); - - for (Operation* cluster_op : cluster_ops) { - // Remove `_tpu_replicate` and `device` attribute from ops in the cluster - // as that information will be present in the `tf_device.cluster`. 
- cluster_op->removeAttr(Identifier::get(kTPUReplicateAttr, context)); - cluster_op->removeAttr(Identifier::get(kDeviceAttr, context)); - cluster_op->moveBefore(terminator); - } -} - -// Replaces uses of cluster ops results outside of cluster with the associated -// `tf_device.cluster` results. -void UpdateClusterResultExternalUses(tf_device::ClusterOp cluster, - llvm::ArrayRef results) { - Block& cluster_block = cluster.GetBody(); + // Replaces uses of cluster ops results outside of cluster with the associated + // `tf_device.cluster` results. for (auto ret_vals : llvm::zip(results, cluster.getResults())) { Value old_ret = std::get<0>(ret_vals); Value new_ret = std::get<1>(ret_vals); - for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) - if (!cluster_block.findAncestorOpInBlock(*use.getOwner())) - use.set(new_ret); + for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) { + Operation* user = use.getOwner(); + if (!body->findAncestorOpInBlock(*user)) use.set(new_ret); + } } -} -// Moves users of cluster that are before the cluster to after the cluster. -void MovePrecedingClusterUsers(tf_device::ClusterOp cluster, - llvm::ArrayRef preceding_users) { + // Move users of cluster that are before the cluster to after the cluster. Operation* op_after_cluster = cluster.getOperation()->getNextNode(); for (Operation* user : preceding_users) user->moveBefore(op_after_cluster); + return cluster; } // Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index` @@ -271,8 +313,7 @@ LogicalResult SortTPUReplicatedInputsByIndex( llvm::SmallVectorImpl* sorted_inputs) { llvm::SmallDenseSet unique_indices; for (Operation* input : inputs) { - int64_t index = - llvm::cast(input).index().getSExtValue(); + int64_t index = llvm::cast(input).index(); if (index < -1) return input->emitOpError() << "requires index to be at least -1, but got " << index; @@ -291,10 +332,8 @@ LogicalResult SortTPUReplicatedInputsByIndex( std::stable_sort( sorted_inputs->begin(), sorted_inputs->end(), [](Operation* l, Operation* r) { - int64_t l_index = - llvm::cast(l).index().getSExtValue(); - int64_t r_index = - llvm::cast(r).index().getSExtValue(); + int64_t l_index = llvm::cast(l).index(); + int64_t r_index = llvm::cast(r).index(); if (l_index == -1 && r_index != -1) return false; if (r_index == -1 && l_index != -1) return true; return l_index < r_index; @@ -350,8 +389,7 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) { return input->emitOpError() << "requires " << num_inputs << " operands"; auto tpu_replicated_input = llvm::cast(input); - int64_t tpu_replicated_input_index = - tpu_replicated_input.index().getSExtValue(); + int64_t tpu_replicated_input_index = tpu_replicated_input.index(); if (is_packed) { packed_inputs.push_back(input->getOperand(0)); packed_input_indices.push_back(tpu_replicated_input_index); @@ -442,10 +480,30 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) { // 8. Wrap cluster (`tf_device.cluster`) in a `tf_device.replicate` if // attribute `num_replicas` is greater than 1. // 9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`. 
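The sorting change above keeps the same ordering rule while reading the index attribute directly: non-negative indices sort ascending, index -1 sorts last, and the stable sort preserves the original order among equal keys. A self-contained sketch of that comparator:

#include <algorithm>
#include <cstdint>
#include <vector>

// Orders replicated-input indices the way SortTPUReplicatedInputsByIndex
// does: ascending, with -1 (unindexed inputs) pushed to the end, stably.
void SortReplicatedInputIndices(std::vector<int64_t>* indices) {
  std::stable_sort(indices->begin(), indices->end(),
                   [](int64_t l, int64_t r) {
                     if (l == -1 && r != -1) return false;  // -1 goes last
                     if (r == -1 && l != -1) return true;
                     return l < r;
                   });
}

For example, {3, -1, 0, 2} sorts to {0, 2, 3, -1}; the uniqueness check for non-negative indices happens separately, before the sort.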
-LogicalResult FormClustersInBlock(Block* block, - const MetadataMap& metadata_map) { +LogicalResult FormClustersInBlock( + Block* block, + const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { + MetadataMap metadata_map; + LogicalResult result = CollectMetadata(block, &metadata_map); + if (failed(result)) return result; + + // If there is no TPUReplicateMetadata op in this block, process blocks in + // regions attached to the op's in the block. + if (metadata_map.empty()) { + for (Operation& op : *block) { + for (Region& region : op.getRegions()) { + if (!llvm::hasSingleElement(region)) + return op.emitOpError("Expected single block region"); + if (failed( + FormClustersInBlock(®ion.front(), resource_alias_analysis))) + return failure(); + } + } + return success(); + } + ClusterMap clusters; - LogicalResult result = CollectAndGroupClusterOps(block, &clusters); + result = CollectAndGroupClusterOps(block, &clusters); if (failed(result)) return result; for (const auto& cluster_metadata_and_ops : clusters) { @@ -464,19 +522,14 @@ LogicalResult FormClustersInBlock(Block* block, } llvm::SmallSetVector preceding_users = - CollectClusterPrecedingUsers(block, cluster_ops); + CollectClusterPrecedingUsers(block, cluster_ops, + resource_alias_analysis); llvm::SmallVector results = CollectClusterResults(block, cluster_ops); - tf_device::ClusterOp cluster = - CreateOpForCluster(cluster_ops.back(), results); - - MoveClusterOpsToCluster(cluster, cluster_ops); - - UpdateClusterResultExternalUses(cluster, results); - - MovePrecedingClusterUsers(cluster, preceding_users.getArrayRef()); + tf_device::ClusterOp cluster = CreateClusterOp( + block, cluster_ops, results, preceding_users.getArrayRef()); auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr); if (!num_replicas || !num_replicas.isa()) @@ -496,17 +549,19 @@ LogicalResult FormClustersInBlock(Block* block, return success(); } -void TPUClusterFormation::runOnFunction() { - MetadataMap metadata_map; - if (failed(CollectMetadata(getFunction(), &metadata_map))) +void TPUClusterFormation::runOnFunction( + FuncOp func, + const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { + if (!llvm::hasSingleElement(func)) { + func.emitOpError("Expecting a single block function"); + return signalPassFailure(); + } + + if (failed(FormClustersInBlock(&func.front(), resource_alias_analysis))) return signalPassFailure(); - for (Block& block : getFunction()) - if (failed(FormClustersInBlock(&block, metadata_map))) - return signalPassFailure(); - // Remove TPUReplicatedInput and TPUReplicatedOutput nodes. - auto remove_result = getFunction().walk([&](Operation* op) { + auto remove_result = func.walk([&](Operation* op) { if (!llvm::isa(op)) return WalkResult::advance(); @@ -533,7 +588,7 @@ void TPUClusterFormation::runOnFunction() { } } // anonymous namespace -std::unique_ptr> CreateTPUClusterFormationPass() { +std::unique_ptr> CreateTPUClusterFormationPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_composite_resource_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_composite_resource_ops.cc new file mode 100644 index 00000000000..b4889f6e52c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_composite_resource_ops.cc @@ -0,0 +1,137 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" + +namespace mlir { +namespace TFTPU { +namespace { + +// Pass that co-locates resource ops that use composite device resources +// (packed tensors) with the underlying physical TPU device. +struct TPUColocateCompositeResourceOps + : public PassWrapper { + void runOnFunction() override; +}; + +// Wraps single op in `tf_device.launch` for explicit device assignment. +void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op, + llvm::StringRef device) { + builder->setInsertionPoint(op); + auto launch = builder->create( + loc, builder->getStringAttr(device), op->getResultTypes()); + launch.body().push_back(new Block); + op->replaceAllUsesWith(launch); + + builder->setInsertionPointToEnd(&launch.GetBody()); + builder->create(loc, op->getResults()); + + // Move op inside cluster. + op->moveBefore(launch.GetBody().getTerminator()); +} + +llvm::SmallVector GetResourceOpsUsingCompositeArgsInReplicate( + tf_device::ReplicateOp replicate) { + llvm::SmallVector resource_users; + const auto add_resource_op_to_list = [&resource_users](Operation* op) { + if (!llvm::isa(op)) return; + + resource_users.emplace_back(op); + }; + + llvm::SmallVector resource_users_to_visit; + for (auto composite_arguments : replicate.GetPackedBlockArguments()) { + for (auto resource_user : composite_arguments.getUsers()) + resource_users_to_visit.emplace_back(resource_user); + } + + while (!resource_users_to_visit.empty()) { + llvm::SmallVector new_resource_users; + + for (auto resource_user : resource_users_to_visit) { + add_resource_op_to_list(resource_user); + + // Account for pass-through identity ops. 
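The traversal above is a worklist walk: it starts from the users of the packed (composite) block arguments, records resource ops it meets, and keeps following pass-through Identity ops so users hidden behind them are visited too. A simplified sketch over a hypothetical user graph (UserNode and the op-name checks are illustrative stand-ins):

#include <string>
#include <vector>

struct UserNode {
  std::string op_name;           // e.g. "tf.ReadVariableOp", "tf.Identity"
  std::vector<UserNode*> users;  // users of this op's result
};

std::vector<UserNode*> CollectCompositeResourceUsers(
    std::vector<UserNode*> worklist) {
  std::vector<UserNode*> resource_users;
  while (!worklist.empty()) {
    std::vector<UserNode*> next;
    for (UserNode* user : worklist) {
      if (user->op_name == "tf.ReadVariableOp" ||
          user->op_name == "tf.AssignVariableOp")
        resource_users.push_back(user);
      if (user->op_name == "tf.Identity")  // pass-through: keep following
        next.insert(next.end(), user->users.begin(), user->users.end());
    }
    worklist.swap(next);  // process the newly discovered users next round
  }
  return resource_users;
}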
+ if (auto pass_through_identity = + llvm::dyn_cast(resource_user)) { + for (auto identity_user : pass_through_identity.output().getUsers()) { + new_resource_users.emplace_back(identity_user); + } + } + } + resource_users_to_visit.swap(new_resource_users); + } + + return resource_users; +} + +void ColocateCompositeResourceOpsInReplicate( + tf_device::ReplicateOp replicate_op, OpBuilder* builder) { + auto devices = replicate_op.devices(); + if (!devices) return; + if (!devices.getValue().get(tensorflow::GetDeviceAliasForLogicalCore(0))) + return; + + const auto composite_resource_users = + GetResourceOpsUsingCompositeArgsInReplicate(replicate_op); + for (auto resource_user : composite_resource_users) { + WrapOpInLaunch(builder, resource_user->getLoc(), resource_user, + tensorflow::GetDeviceAliasForLogicalCore(0)); + } +} + +void TPUColocateCompositeResourceOps::runOnFunction() { + // Find all the executes first, since we will mutate the nodes around each + // execute in the same tf_device.replicate op. + llvm::SmallVector execute_launches; + getFunction().walk([&](tf_device::LaunchOp op) { + if (op.WrapsSingleOp() && + llvm::isa( + op.GetBody().front())) + execute_launches.push_back(op); + }); + + OpBuilder builder(&getContext()); + for (auto execute_launch : execute_launches) { + auto replicate = execute_launch.getParentOfType(); + if (!replicate) continue; + + ColocateCompositeResourceOpsInReplicate(replicate, &builder); + } +} + +} // namespace + +std::unique_ptr> CreateTPUColocateCompositeResourceOps() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-colocate-composite-resource-ops", + "Colocate resource with composite device assignment to TPU device."); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc index 41362465cd9..59f36e03fbb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc @@ -185,7 +185,7 @@ bool HandleReplicatedInputs( const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { // We need to know the devices to copy to. if (!replicate.devices()) return false; - int64_t num_replicas = replicate.n().getZExtValue(); + int64_t num_replicas = replicate.n(); auto inputs = replicate.getOperands() .drop_front(replicate_arg_index * num_replicas) .take_front(num_replicas); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc index 2be6ee7a78c..6e106b278fe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc @@ -23,10 +23,10 @@ limitations under the License. 
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project @@ -34,7 +34,9 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" @@ -113,12 +115,23 @@ tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op, return launch; } +// Checks if an operation is a supported TPU embedding op. +bool IsEmbeddingOp(Operation* op) { + return isa(op); +} + // Returns a set of ops that are outside compiled and can be extracted to before // the TPU computation. These ops are either connected to the inputs of the TPU // computation or other ops that can be extracted, and have no operands from // other ops in the TPU computation that cannot be extracted. llvm::SmallVector FindOutsideCompiledOpsAtHead( + const TF::SideEffectAnalysis& side_effect_analysis, tf_device::ClusterOp cluster) { + const auto& analysis = side_effect_analysis.GetAnalysisForFunc( + cluster.getParentOfType()); Region* cluster_region = &cluster.body(); llvm::SmallSetVector head_outside_compiled_ops; @@ -127,6 +140,24 @@ llvm::SmallVector FindOutsideCompiledOpsAtHead( if (!HasOutsideCompilationAttribute(&cluster_op)) continue; // An outside compiled op can be extracted if its operands are not from // other ops in the cluster that cannot be extracted. + + // Check if the side effecting op right before this side effecting op, if + // it is side effecting, can be head extracted. Because of op ordering due + // to side effects, if this is not true, this op cannot be head extracted. + // TODO(lyandy): Remove special handling of embedding ops. Currently the IR + // is in a topological sort order and depending on that ordering, embedding + // ops may prevent other ops from being head extracted. + auto predecessors = analysis.DirectControlPredecessors(&cluster_op); + if (!predecessors.empty() && !IsEmbeddingOp(&cluster_op)) { + bool skip = false; + for (Operation* predecessor : llvm::reverse(predecessors)) { + if (IsEmbeddingOp(predecessor)) continue; + skip = !head_outside_compiled_ops.contains(predecessor); + break; + } + if (skip) continue; + } + auto walk_result = cluster_op.walk([&](Operation* op) { for (Value operand : op->getOperands()) { Operation* operand_op = GetOpOfValue(operand); @@ -168,11 +199,11 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster, // Extracts and move outside compiled ops that have no dependencies in the // cluster to before the cluster. 
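The predecessor gate added above can be read as: for a candidate head op, look at its direct control predecessors from nearest to farthest, ignore embedding ops (the special case flagged in the TODO), and allow extraction only if the first relevant predecessor was itself already extracted. A small sketch of that decision, with OpRef and the extracted-ID set as hypothetical stand-ins for the side-effect analysis results:

#include <set>
#include <vector>

struct OpRef {
  int id = 0;
  bool is_embedding = false;
};

bool PredecessorsAllowHeadExtraction(
    const std::vector<OpRef>& predecessors,  // program order, nearest last
    const std::set<int>& already_extracted_ids) {
  for (auto it = predecessors.rbegin(); it != predecessors.rend(); ++it) {
    if (it->is_embedding) continue;  // embedding ops are skipped for now
    // Decision rests on the nearest non-embedding predecessor only.
    return already_extracted_ids.count(it->id) > 0;
  }
  return true;  // only embedding (or no) predecessors: nothing blocks it
}

The tail-extraction check later in this file is the mirror image, walking direct control successors instead of predecessors.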
mlir::LogicalResult LiftHeadOutsideCompiledOps( - OpBuilder* builder, const mlir::TF::RuntimeDevices& devices, - tf_device::ClusterOp cluster, std::string* host_device, - bool* cluster_updated) { + OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis, + const mlir::TF::RuntimeDevices& devices, tf_device::ClusterOp cluster, + std::string* host_device, bool* cluster_updated) { llvm::SmallVector head_outside_compiled_ops = - FindOutsideCompiledOpsAtHead(cluster); + FindOutsideCompiledOpsAtHead(side_effect_analysis, cluster); if (head_outside_compiled_ops.empty()) return success(); if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster, host_device))) @@ -191,9 +222,12 @@ mlir::LogicalResult LiftHeadOutsideCompiledOps( // TPU computation or other ops that can be extracted, and have no results used // by other ops in the TPU computation that cannot be extracted. void FindOutsideCompiledOpsAtTailAndClusterResults( + const TF::SideEffectAnalysis& side_effect_analysis, tf_device::ClusterOp cluster, llvm::SmallVectorImpl* tail_outside_compiled_ops, llvm::SmallVectorImpl* cluster_results) { + const auto& analysis = side_effect_analysis.GetAnalysisForFunc( + cluster.getParentOfType()); Region* cluster_region = &cluster.body(); llvm::SmallSetVector tail_outside_compiled_ops_set; Operation* terminator = cluster.GetBody().getTerminator(); @@ -205,6 +239,24 @@ void FindOutsideCompiledOpsAtTailAndClusterResults( for (Operation& cluster_op : cluster_ops) { if (!HasOutsideCompilationAttribute(&cluster_op)) continue; + // Check if the side effecting op right after this side effecting op, if + // it is side effecting, can be tail extracted. Because of op ordering due + // to side effects, if this is not true, this op cannot be tail extracted. + // TODO(lyandy): Remove special handling of embedding ops. Currently the IR + // is in a topological sort order and depending on that ordering, embedding + // ops may prevent other ops from being tail extracted. + auto successors = analysis.DirectControlSuccessors( + &cluster_op, [&terminator](Operation* op) { return op != terminator; }); + if (!successors.empty() && !IsEmbeddingOp(&cluster_op)) { + bool skip = false; + for (Operation* successor : successors) { + if (IsEmbeddingOp(successor)) continue; + skip = !tail_outside_compiled_ops_set.contains(successor); + break; + } + if (skip) continue; + } + llvm::SmallVector results_to_forward; bool can_be_extracted = llvm::all_of(cluster_op.getUsers(), [&](Operation* op) { @@ -293,13 +345,14 @@ tf_device::ClusterOp UpdateClusterResults( // Extracts and move outside compiled ops that do not create dependencies in the // cluster to after the cluster. 
mlir::LogicalResult LiftTailOutsideCompiledOps( - OpBuilder* builder, const mlir::TF::RuntimeDevices& devices, - std::string host_device, tf_device::ClusterOp* cluster, - bool* cluster_updated) { + OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis, + const mlir::TF::RuntimeDevices& devices, std::string host_device, + tf_device::ClusterOp* cluster, bool* cluster_updated) { llvm::SmallVector tail_outside_compiled_ops; llvm::SmallVector cluster_results; - FindOutsideCompiledOpsAtTailAndClusterResults( - *cluster, &tail_outside_compiled_ops, &cluster_results); + FindOutsideCompiledOpsAtTailAndClusterResults(side_effect_analysis, *cluster, + &tail_outside_compiled_ops, + &cluster_results); if (tail_outside_compiled_ops.empty()) return success(); if (host_device.empty()) @@ -331,7 +384,8 @@ void RemoveClusterAliasedOutputs(OpBuilder* builder, for (auto result : llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) { Value cluster_terminator_operand = std::get<0>(result); - if (cluster.getOperation()->isProperAncestor( + if (cluster_terminator_operand.getDefiningOp() && + cluster.getOperation()->isProperAncestor( cluster_terminator_operand.getDefiningOp())) { new_cluster_results.push_back(cluster_terminator_operand); new_cluster_result_types.push_back(cluster_terminator_operand.getType()); @@ -364,6 +418,7 @@ struct TPUExtractHeadTailOutsideCompilation }; void TPUExtractHeadTailOutsideCompilation::runOnOperation() { + auto& side_effect_analysis = getAnalysis(); // Get runtime devices information from the closest parent module. auto module = getOperation(); mlir::TF::RuntimeDevices devices; @@ -378,10 +433,12 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() { for (tf_device::ClusterOp cluster : clusters) { std::string host_device; bool cluster_updated = false; - if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster, - &host_device, &cluster_updated)) || - failed(LiftTailOutsideCompiledOps(&builder, devices, host_device, - &cluster, &cluster_updated))) + if (failed(LiftHeadOutsideCompiledOps(&builder, side_effect_analysis, + devices, cluster, &host_device, + &cluster_updated)) || + failed(LiftTailOutsideCompiledOps(&builder, side_effect_analysis, + devices, host_device, &cluster, + &cluster_updated))) return signalPassFailure(); if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc index cbea4ae6544..303b69c2730 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc @@ -17,11 +17,21 @@ limitations under the License. 
#include #include +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -77,51 +87,315 @@ struct TPUExtractOutsideCompilation void runOnOperation() override; }; -// Collects and clusters ops in `block` with the same `_xla_outside_compilation` -// attribute into `clusters` This returns an error if a -// `_xla_outside_compilation` attribute of an op is empty. -LogicalResult CollectAndGroupOutsideClusterOps(Block* block, - OutsideClusterMap* clusters) { - for (Operation& op : *block) { - if (auto attr = op.getAttrOfType(kXlaOutsideCompilationAttr)) { - if (attr.getValue().empty()) - return op.emitError() - << "attribute '" << kXlaOutsideCompilationAttr << "' is empty"; +// Holds information about control flow operations that wrap outside compiled +// op. Currently only tf.IfRegion and tf.WhileRegion ops are supported. +class ControlFlowStackInfo { + public: + enum ControlFlowBranchType { kIfThen, kIfElse, kWhileCond, kWhileBody }; - auto it = clusters->try_emplace(attr.getValue()); - it.first->getSecond().push_back(&op); + explicit ControlFlowStackInfo(Operation* wrapping_op, Operation* nested_op) + : callsite_op_(wrapping_op) { + if (auto control_flow_op = llvm::dyn_cast(callsite_op_)) { + auto parent_region = nested_op->getParentRegion(); + if (&control_flow_op.then_branch() == parent_region) { + type_ = ControlFlowBranchType::kIfThen; + } else { + type_ = ControlFlowBranchType::kIfElse; + } + } else if (auto control_flow_op = + llvm::dyn_cast(callsite_op_)) { + auto parent_region = nested_op->getParentRegion(); + if (&control_flow_op.cond() == parent_region) { + type_ = ControlFlowBranchType::kWhileCond; + } else { + type_ = ControlFlowBranchType::kWhileBody; + } + } else { + assert(false); } } - return success(); + Value GetIfPredicateValue() { + auto if_op = llvm::cast(callsite_op_); + return if_op.cond(); + } + + ControlFlowBranchType GetBranchType() const { return type_; } + + Operation* GetCallSiteOp() const { return callsite_op_; } + + private: + ControlFlowBranchType type_; + + // `this` does not hold ownership of `callsite_op_`. + Operation* callsite_op_; +}; + +// Returns a list of ControlFlowStackInfo that represents a stack of control +// flow operations that wraps `op`. 
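ControlFlowStackInfo above records which IfRegion/WhileRegion branch an outside compiled op lives in, and the stack is built by walking parent ops upward until the TPU cluster is reached, keeping frames outermost-first. A plain sketch of that upward walk over a hypothetical parent-linked node type:

#include <string>
#include <vector>

struct RegionNode {
  std::string kind;  // "tf.IfRegionOp", "tf.WhileRegionOp", or anything else
  RegionNode* parent = nullptr;
};

// Collects the control-flow frames wrapping `op`, stopping at `cluster`,
// ordered from outermost to innermost like GetControlFlowStackForOp.
std::vector<RegionNode*> ControlFlowStackFor(RegionNode* op,
                                             RegionNode* cluster) {
  std::vector<RegionNode*> stack;
  for (RegionNode* cur = op->parent; cur != nullptr && cur != cluster;
       cur = cur->parent) {
    if (cur->kind == "tf.IfRegionOp" || cur->kind == "tf.WhileRegionOp")
      stack.insert(stack.begin(), cur);  // prepend to keep outermost first
  }
  return stack;
}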
+llvm::SmallVector GetControlFlowStackForOp( + tf_device::ClusterOp tpu_cluster, Operation* op) { + assert(tpu_cluster.getOperation()->isProperAncestor(op)); + + llvm::SmallVector controlflow_stack; + Operation* op_in_stack = op; + while (op_in_stack != tpu_cluster.getOperation()) { + auto parent_op = op_in_stack->getParentOp(); + if (llvm::isa(parent_op)) { + controlflow_stack.insert(controlflow_stack.begin(), + ControlFlowStackInfo(parent_op, op_in_stack)); + } + op_in_stack = parent_op; + } + + return controlflow_stack; } -// Moves `cluster_ops` to associated `launch_op` body. -void MoveOutsideClusterOpsToLaunchOp(tf_device::LaunchOp launch_op, - llvm::ArrayRef cluster_ops) { - MLIRContext* context = launch_op.getContext(); - Operation* terminator = launch_op.GetBody().getTerminator(); +// Creates a IfRegionOp with `predicate` and then/else region with yield op and +// an empty block. +TF::IfRegionOp CloneEmptyIfWithPredicate(Value predicate, bool is_stateless, + Location loc, OpBuilder* builder) { + auto host_side_if = builder->create( + loc, llvm::SmallVector{}, predicate, is_stateless); + // Create empty then branch region. + auto& then_branch = host_side_if.then_branch(); + then_branch.push_back(new Block); + builder->setInsertionPointToEnd(&then_branch.front()); + builder->create(loc, /*operands=*/ArrayRef{}); + + // Create empty else branch region. + auto& else_branch = host_side_if.else_branch(); + else_branch.push_back(new Block); + builder->setInsertionPointToEnd(&else_branch.front()); + builder->create(loc, /*operands=*/ArrayRef{}); + return host_side_if; +} + +// Replicates tf.IfRegion op to host side computation. +Operation* ReplicateIf(const ControlFlowStackInfo& controlflow_info, + llvm::StringRef outside_cluster_name, + Value compilation_key, OpBuilder* builder, + int* send_recv_counter) { + // Create XlaSendToHostOp to send predicate value from device to host. + OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint(); + auto if_callsite_op = + llvm::cast(controlflow_info.GetCallSiteOp()); + builder->setInsertionPoint(if_callsite_op); + + const auto predicate_send_recv_key = + llvm::formatv("if_predicate_channel_{0}_{1}", outside_cluster_name, + *send_recv_counter) + .str(); + *send_recv_counter += 1; + + auto predicate = if_callsite_op.cond(); + auto predicate_shape = predicate.getType(); + builder->create(if_callsite_op.getLoc(), predicate, + predicate_send_recv_key); + + // Create XlaRecvAtHostOp to receive predicate value from host. + builder->restoreInsertionPoint(insert_point); + auto recv_predicate_at_host = builder->create( + if_callsite_op.getLoc(), llvm::ArrayRef{predicate_shape}, + /*dynamic_key=*/compilation_key, + builder->getStringAttr(predicate_send_recv_key), + /*device_ordinal=*/builder->getI64IntegerAttr(0)); + + // Create host side if op. + return CloneEmptyIfWithPredicate(recv_predicate_at_host.getResult(0), + if_callsite_op.is_stateless(), + if_callsite_op.getLoc(), builder); +} + +// Creates a WhileRegionOp cond and body regions with yield op and +// an empty body. +TF::WhileRegionOp CloneEmptyWhile(bool is_stateless, + uint64_t parallel_iterations, Location loc, + OpBuilder* builder) { + auto host_side_while = builder->create( + loc, /*output=*/ArrayRef{}, /*input=*/ArrayRef{}, + is_stateless, parallel_iterations); + + // Create empty else branch region. 
+ auto& body = host_side_while.body(); + body.push_back(new Block); + builder->setInsertionPointToEnd(&body.front()); + builder->create(loc, /*operands=*/ArrayRef{}); + return host_side_while; +} + +// Replicates tf.WhileRegion op to host side computation. +Operation* ReplicateWhile(const ControlFlowStackInfo& controlflow_info, + llvm::StringRef outside_cluster_name, + Value compilation_key, OpBuilder* builder, + int* send_recv_counter) { + // Create XlaSendToHostOp to send cond region output from device to host. + OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint(); + auto while_callsite_op = + llvm::cast(controlflow_info.GetCallSiteOp()); + builder->setInsertionPoint(while_callsite_op.cond().front().getTerminator()); + + const auto condition_send_recv_key = + llvm::formatv("while_condition_channel_{0}_{1}", outside_cluster_name, + *send_recv_counter) + .str(); + *send_recv_counter += 1; + auto condition = + while_callsite_op.cond().front().getTerminator()->getOperand(0); + builder->create(while_callsite_op.getLoc(), condition, + condition_send_recv_key); + builder->restoreInsertionPoint(insert_point); + + auto host_side_while = CloneEmptyWhile( + while_callsite_op.is_stateless(), while_callsite_op.parallel_iterations(), + while_callsite_op.getLoc(), builder); + + // Create cond region and yield the condition from the device. + auto& cond = host_side_while.cond(); + cond.push_back(new Block); + builder->setInsertionPointToEnd(&cond.front()); + auto recv_condition_at_host = builder->create( + while_callsite_op.getLoc(), llvm::ArrayRef{condition.getType()}, + /*dynamic_key=*/compilation_key, + builder->getStringAttr(condition_send_recv_key), + /*device_ordinal=*/builder->getI64IntegerAttr(0)); + builder->create(while_callsite_op.getLoc(), + recv_condition_at_host.getResults()); + + return host_side_while; +} + +// TODO(b/157054714): Use a better abstraction instead of +// _TPUCompileMlirOp and _XlaRecvAtHostOp and _XlaSendFromHostOp. +// Creates a compilation key as placeholder. A placeholder compilation cache key +// is created because it is a required input to _XlaRecvAtHost and +// _XlaSendFromHost but the _TPUCompileMlir has not yet been created for the TPU +// cluster that contains the outside compiled ops. This placeholder should be +// replaced by the TPU cluster _TPUCompileMlir in a subsequent pass. +Value CreateCompilationKeyPlaceholder(Location loc, OpBuilder* builder) { + auto result_type = + RankedTensorType::get({2}, builder->getType()); + return builder->create( + loc, /*program=*/result_type, llvm::ArrayRef{}); +} + +// Replicates the control flow operations that wraps outside compiled ops to +// `destination_block`. +Operation* ReplicateControlFlowStack( + llvm::StringRef outside_cluster_name, + const llvm::SmallVectorImpl& stack_info, + tf_device::ClusterOp tpu_cluster, ModuleOp module, Value compilation_key, + Block* destination_block, int* send_recv_counter) { + assert(stack_info.size()); + OpBuilder builder = OpBuilder::atBlockTerminator(destination_block); + Operation* previous_replicated_controlflow_op = nullptr; + for (const auto& controlflow_stack_info : stack_info) { + // Create control flow op given provided insertion point and + // ControlFlowStackInfo. 
+ if (auto control_flow_op = llvm::dyn_cast( + controlflow_stack_info.GetCallSiteOp())) { + previous_replicated_controlflow_op = + ReplicateIf(controlflow_stack_info, outside_cluster_name, + compilation_key, &builder, send_recv_counter); + auto if_op = + llvm::cast(previous_replicated_controlflow_op); + auto type = controlflow_stack_info.GetBranchType(); + + // Update the insertion point to proper region inside the newly created + // control flow op. + if (type == ControlFlowStackInfo::kIfThen) { + builder.setInsertionPoint(&if_op.then_branch().front().front()); + } else { + builder.setInsertionPoint(&if_op.else_branch().front().front()); + } + } else if (auto control_flow_op = llvm::dyn_cast( + controlflow_stack_info.GetCallSiteOp())) { + previous_replicated_controlflow_op = + ReplicateWhile(controlflow_stack_info, outside_cluster_name, + compilation_key, &builder, send_recv_counter); + auto while_op = + llvm::cast(previous_replicated_controlflow_op); + auto type = controlflow_stack_info.GetBranchType(); + if (type == ControlFlowStackInfo::kWhileCond) { + builder.setInsertionPoint(&while_op.cond().front().front()); + } else { + builder.setInsertionPoint(&while_op.body().front().front()); + } + } + } + + // Return operation which should be used to as the insertion point to create + // send/recv ops. + if (auto inner_most_if = + llvm::dyn_cast(previous_replicated_controlflow_op)) { + auto inner_most_controlflow_stack = stack_info.back(); + if (inner_most_controlflow_stack.GetBranchType() == + ControlFlowStackInfo::kIfThen) { + return inner_most_if.then_branch().front().getTerminator(); + } else { + return inner_most_if.else_branch().front().getTerminator(); + } + } else if (auto inner_most_while = llvm::dyn_cast( + previous_replicated_controlflow_op)) { + auto inner_most_controlflow_stack = stack_info.back(); + if (inner_most_controlflow_stack.GetBranchType() == + ControlFlowStackInfo::kWhileCond) { + return &inner_most_while.cond().front().front(); + } else { + return inner_most_while.body().front().getTerminator(); + } + } + return destination_block->getTerminator(); +} + +// Collects and clusters ops in `block` with the same `_xla_outside_compilation` +// attribute into `clusters` This returns an error if a +// `_xla_outside_compilation` attribute of an op is empty. +// TODO(b/163141763): Make sure ops inside control flow regions are not outside +// compiled if the entire control flow op is marked as outside compiled. +LogicalResult CollectAndGroupOutsideClusterOps(Block* block, + OutsideClusterMap* clusters) { + auto walk_result = block->walk([&](Operation* op) { + if (auto attr = op->getAttrOfType(kXlaOutsideCompilationAttr)) { + if (attr.getValue().empty()) { + op->emitError() << "attribute '" << kXlaOutsideCompilationAttr + << "' is empty"; + return WalkResult::interrupt(); + } + + auto it = clusters->try_emplace(attr.getValue()); + it.first->getSecond().push_back(op); + } + return WalkResult::advance(); + }); + + return failure(walk_result.wasInterrupted()); +} + +// Moves `cluster_ops` before `op`. +void MoveOutsideClusterOpsBeforeOp(Operation* op, + llvm::ArrayRef cluster_ops, + MLIRContext* context) { for (Operation* cluster_op : cluster_ops) { // Remove `_xla_outside_compilation` and `device` attribute from ops in the // cluster as that information will be present in the `launch_op`. 
cluster_op->removeAttr( Identifier::get(kXlaOutsideCompilationAttr, context)); cluster_op->removeAttr(Identifier::get(kDeviceAttr, context)); - cluster_op->moveBefore(terminator); + cluster_op->moveBefore(op); } } -// Creates a `tf_device::LaunchOp` to wrap cluster ops. +// Creates a `tf_device.launch` to wrap cluster ops. tf_device::LaunchOp CreateLaunchOpForOutsideCluster( OpBuilder* builder, Operation* last_cluster_op, llvm::StringRef host_device) { // An empty string placeholder is used for the device as that will be later // populated with the device of the associated TPUReplicateMetadata op. - llvm::SmallVector result_types; auto launch_op = builder->create( last_cluster_op->getLoc(), builder->getStringAttr(host_device), - result_types); + /*result_types=*/ArrayRef{}); launch_op.body().push_back(new Block); @@ -133,21 +407,61 @@ tf_device::LaunchOp CreateLaunchOpForOutsideCluster( return launch_op; } -// Extracts all externally provided operands of `cluster_ops`. +// Extracts all externally provided operands of `host_cluster_ops`. llvm::SmallSetVector GetExternalOperands( - llvm::ArrayRef cluster_ops) { + tf_device::ClusterOp tpu_cluster, + llvm::ArrayRef host_cluster_ops) { llvm::SmallSetVector external_values; - for (Operation* op : cluster_ops) { - for (Value v : op->getOperands()) { - Operation* defining_op = v.getDefiningOp(); - if (!defining_op) continue; - bool is_external = llvm::none_of(cluster_ops, [&](Operation* cluster_op) { - return defining_op == cluster_op; - }); + for (Operation* host_cluster_op : host_cluster_ops) { + auto cluster_op_parent_region = host_cluster_op->getParentRegion(); + host_cluster_op->walk([&](Operation* op) { + auto region = op->getParentRegion(); - if (is_external) external_values.insert(v); - } + if (region == cluster_op_parent_region) { + // For op operands, add operand defining ops, if they are not included + // in `host_cluster_ops`. + for (Value v : op->getOperands()) { + Operation* defining_op = v.getDefiningOp(); + bool is_external = false; + if (defining_op) { + is_external = + llvm::none_of(host_cluster_ops, [&](Operation* cluster_op) { + return defining_op == cluster_op; + }); + } else { + if (auto block_arg = v.dyn_cast()) { + if (block_arg.getParentRegion() == cluster_op_parent_region) + is_external = true; + } + } + if (is_external) external_values.insert(v); + } + } else { + llvm::SetVector external_captured_inputs; + visitUsedValuesDefinedAbove(*region, *region, [&](OpOperand* operand) { + Region* operand_defined_region = operand->get().getParentRegion(); + if (!tpu_cluster.body().isAncestor(operand_defined_region)) return; + // If the host_cluster_op is regional control flow (if, while), + // then check if the operand_defined_region is an ancestor of the + // control flow regions. 
+ if (auto if_op = llvm::dyn_cast(host_cluster_op)) { + if (if_op.then_branch().isAncestor(operand_defined_region) || + if_op.else_branch().isAncestor(operand_defined_region)) + return; + } + if (auto while_op = + llvm::dyn_cast(host_cluster_op)) { + if (while_op.cond().isAncestor(operand_defined_region) || + while_op.body().isAncestor(operand_defined_region)) + return; + } + external_captured_inputs.insert(operand->get()); + }); + external_values.insert(external_captured_inputs.begin(), + external_captured_inputs.end()); + } + }); } return external_values; @@ -212,34 +526,42 @@ TF::_XlaHostComputeMlirOp CreateHostCompute( } void MoveOutsideCompiledOps( - tf_device::ClusterOp tpu_cluster, llvm::StringRef outside_cluster_name, - tf_device::LaunchOp host_launch_op, llvm::ArrayRef cluster_ops, + ModuleOp module, tf_device::ClusterOp tpu_cluster, + llvm::StringRef outside_cluster_name, tf_device::LaunchOp host_launch_op, + llvm::ArrayRef cluster_ops, const llvm::SmallSetVector& external_inputs, llvm::ArrayRef external_outputs) { + // Since ops in `cluster_ops` do not cross function/control flow boundary, it + // is sufficient to identify the control flow that wraps `cluster_ops` by + // looking at any arbitary op inside `cluster_ops`. + auto controlflow_stack = + GetControlFlowStackForOp(tpu_cluster, cluster_ops.front()); + + Value compilation_key; + if (!controlflow_stack.empty() || !external_inputs.empty() || + !external_outputs.empty()) { + OpBuilder builder(&host_launch_op.GetBody().front()); + compilation_key = + CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), &builder); + } + + Operation* insertion_op = nullptr; + if (controlflow_stack.empty()) { + insertion_op = host_launch_op.GetBody().getTerminator(); + } else { + int send_recv_counter = 0; + insertion_op = ReplicateControlFlowStack( + outside_cluster_name, controlflow_stack, tpu_cluster, module, + compilation_key, &host_launch_op.GetBody(), &send_recv_counter); + } + + MLIRContext* context = host_launch_op.getContext(); if (external_inputs.empty() && external_outputs.empty()) { - MoveOutsideClusterOpsToLaunchOp(host_launch_op, cluster_ops); + MoveOutsideClusterOpsBeforeOp(insertion_op, cluster_ops, context); return; } - OpBuilder builder(host_launch_op.GetBody().getTerminator()); - auto result_type = - RankedTensorType::get({}, builder.getType()); - - std::string txt_metadata; - std::string txt_module; - // TODO(b/157054714): Use a better abstraction instead of _TPUCompileMlirOp - // and _XlaRecvAtHostOp and _XlaSendFromHostOp. - - // A placeholder compilation cache key is created because it is a required - // input to _XlaRecvAtHost and _XlaSendFromHost but the _TPUCompileMlir has - // not yet been created for the TPU cluster that contains the outside compiled - // ops. This placeholder should be replaced by the TPU cluster _TPUCompileMlir - // in a subsequent pass. 
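For context, the GetExternalOperands change above leans on MLIR's visitUsedValuesDefinedAbove utility (from RegionUtils.h, already included in this file) to find values that a nested region captures from its surroundings. A minimal sketch of that pattern, assuming only standard MLIR headers; the helper name CollectCapturedValues is illustrative and not part of the patch:

#include "llvm/ADT/SetVector.h"
#include "mlir/IR/Operation.h"            // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project

// Returns every value that is used inside `op`'s regions but defined above
// them, e.g. values implicitly captured by a tf.IfRegion or tf.WhileRegion.
llvm::SetVector<mlir::Value> CollectCapturedValues(mlir::Operation* op) {
  llvm::SetVector<mlir::Value> captured;
  for (mlir::Region& region : op->getRegions())
    mlir::visitUsedValuesDefinedAbove(region, region,
                                      [&](mlir::OpOperand* operand) {
                                        captured.insert(operand->get());
                                      });
  return captured;
}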
- auto compilation_key = - builder.create( - tpu_cluster.getLoc(), /*program=*/result_type, - llvm::ArrayRef{}); - + OpBuilder builder(insertion_op); llvm::SmallVector host_output_types; for (const auto& external_input : external_inputs) host_output_types.push_back(external_input.getType()); @@ -250,6 +572,7 @@ void MoveOutsideCompiledOps( std::string retvals_communication_key = llvm::formatv("host_compute_channel_{0}_retvals", outside_cluster_name) .str(); + auto recv_at_host = builder.create( tpu_cluster.getLoc(), host_output_types, /*dynamic_key=*/compilation_key, @@ -259,9 +582,9 @@ void MoveOutsideCompiledOps( auto host_compute = CreateHostCompute( &builder, tpu_cluster, cluster_ops, external_inputs, external_outputs, args_communication_key, retvals_communication_key); - MoveOutsideClusterOpsToLaunchOp(host_launch_op, cluster_ops); + MoveOutsideClusterOpsBeforeOp(insertion_op, cluster_ops, context); - builder.setInsertionPoint(host_launch_op.GetBody().getTerminator()); + builder.setInsertionPoint(insertion_op); builder.create( tpu_cluster.getLoc(), external_outputs, /*dynamic_key=*/compilation_key, @@ -279,7 +602,8 @@ void MoveOutsideCompiledOps( // Creates a `parallel_execute` op in place of launch with 'clusters` and // 'launch` as regions. -void CreateParallelExecuteFromOutsideClusters(tf_device::ClusterOp tpu_cluster, +void CreateParallelExecuteFromOutsideClusters(ModuleOp module, + tf_device::ClusterOp tpu_cluster, const OutsideClusterMap& clusters, llvm::StringRef host_device) { OpBuilder builder(tpu_cluster); @@ -295,18 +619,18 @@ void CreateParallelExecuteFromOutsideClusters(tf_device::ClusterOp tpu_cluster, Block& outside_block = parallel_execute_op.GetRegionBlockWithIndex(cluster.index()); + builder.setInsertionPointToEnd(&outside_block); tf_device::LaunchOp host_launch_op = CreateLaunchOpForOutsideCluster( &builder, cluster_ops.back(), host_device); // Determine if there are any inputs that are provided out of cluster. - auto external_inputs = GetExternalOperands(cluster_ops); + auto external_inputs = GetExternalOperands(tpu_cluster, cluster_ops); auto external_outputs = GetExternalOutputs(cluster_ops); - MoveOutsideCompiledOps(tpu_cluster, cluster.value().getFirst(), + MoveOutsideCompiledOps(module, tpu_cluster, cluster.value().getFirst(), host_launch_op, cluster_ops, external_inputs, external_outputs); - builder.setInsertionPointToEnd(&outside_block); builder.create(tpu_cluster.getLoc(), ArrayRef{}); @@ -352,7 +676,8 @@ void TPUExtractOutsideCompilation::runOnOperation() { std::string host_device; tensorflow::GetHostDeviceOutsideComputation(devices, tpu_cluster, &host_device); - CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters, + + CreateParallelExecuteFromOutsideClusters(module, tpu_cluster, clusters, host_device); return WalkResult::advance(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_identity_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_identity_pruning.cc new file mode 100644 index 00000000000..32b1eb340d6 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_identity_pruning.cc @@ -0,0 +1,113 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TFTPU { + +namespace { + +// This pass removes Identity/IdentityN ops from the TPU computation and +// reachable functions. +// TODO(lyandy): Remove this pass once resource op lifting is migrated to use +// resource alias analysis and support region based control flow. Removing +// Identity ops may remove `_XlaSharding` annotation attribute if Identity ops +// are used to propagate such information. + +struct TPUIdentityPruning + : public PassWrapper> { + void runOnOperation() override; +}; + +// Collects all reachable functions (via call ops) from a given region. +SmallVector CollectReachableFunctions(Region& region) { + llvm::SmallPtrSet reachable_funcs; + + auto collect_reachable_funcs = + [&reachable_funcs](Region& src, SmallVectorImpl& funcs_to_visit) { + src.walk([&reachable_funcs, &funcs_to_visit](CallOpInterface call_op) { + auto func = dyn_cast_or_null(call_op.resolveCallable()); + if (func && reachable_funcs.insert(func).second) + funcs_to_visit.push_back(func); + }); + }; + + SmallVector funcs_to_visit; + collect_reachable_funcs(region, funcs_to_visit); + + while (!funcs_to_visit.empty()) { + SmallVector new_funcs_to_visit; + for (FuncOp func_to_visit : funcs_to_visit) { + if (!func_to_visit.getCallableRegion()) continue; + collect_reachable_funcs(*func_to_visit.getCallableRegion(), + new_funcs_to_visit); + } + funcs_to_visit.swap(new_funcs_to_visit); + } + + return llvm::to_vector<4>(reachable_funcs); +} + +// Removes Identity/IdentityN ops from a region and forwards its operands to its +// results. 
+void RemoveIdentityFromRegion(Region& region) { + region.walk([](Operation* op) { + if (isa(op)) { + op->replaceAllUsesWith(op->getOperands()); + op->erase(); + } + }); +} + +void TPUIdentityPruning::runOnOperation() { + SmallVector clusters; + getOperation().walk( + [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); }); + + for (tf_device::ClusterOp cluster : clusters) { + RemoveIdentityFromRegion(cluster.body()); + auto reachable_funcs = CollectReachableFunctions(cluster.body()); + for (FuncOp reachable_func : reachable_funcs) + RemoveIdentityFromRegion(*reachable_func.getCallableRegion()); + } +} + +} // anonymous namespace + +std::unique_ptr> CreateTPUIdentityPruningPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-identity-pruning", + "Removes Identity/IdentityN ops from the TPU computation"); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc index be01b7644ea..900bdf6f519 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -22,6 +23,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -33,8 +35,10 @@ namespace { constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; struct TPUOutsideCompilationCluster - : public PassWrapper { - void runOnFunction() override; + : public TF::PerFunctionAggregateAnalysisConsumerPass< + TPUOutsideCompilationCluster, TF::SideEffectAnalysis> { + void runOnFunction(FuncOp func, + const TF::SideEffectAnalysis::Info& side_effect_analysis); }; // Represents an outside compiled cluster. All ops that are added to the same @@ -44,72 +48,86 @@ class OutsideCompiledCluster { explicit OutsideCompiledCluster(int number) : cluster_name_(llvm::formatv("cluster{0}", number).str()) {} - // Attempts to add an op to this cluster. - // This function requires all ops to be added before their uses. - bool AddOp(Operation* op) { + // Attempts to add an op to this cluster. Ops can be grouped to the same + // cluster if they have data dependency and are inside the same block. + bool AddOp(Operation* op, + const TF::SideEffectAnalysis::Info& side_effect_analysis) { // Check if the op is safe to add before adding it. - bool add = IsSafeToAdd(op); - if (add) { - // Set the ops kXlaOutsideCompilationAttr to the cluster name. + if (IsSafeToAdd(op, side_effect_analysis)) { op->setAttr(kXlaOutsideCompilationAttr, StringAttr::get(cluster_name_, op->getContext())); - - // Since we are adding the op to the cluster, the op is no longer - // considered a user of this cluster. 
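The tpu_identity_pruning.cc pass added above is a plain ModuleOp pass. Assuming its factory is declared in the usual TFTPU passes header (an assumption; the header change is not shown in this diff), wiring it into a pass pipeline would look roughly like this sketch:

#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

void AddTPUIdentityPruning(mlir::PassManager& pm) {
  // Prunes tf.Identity/tf.IdentityN inside tf_device.cluster regions and in
  // functions reachable from them via call ops.
  pm.addPass(mlir::TFTPU::CreateTPUIdentityPruningPass());
}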
- users_.erase(op); + host_cluster_ops_.insert(op); + return true; } - - // Add this op's users to the cluster users. - users_.insert(op->user_begin(), op->user_end()); - return add; + return false; } private: // Checks if it is safe for an op to be merged into this cluster. - bool IsSafeToAdd(Operation* op) { + bool IsSafeToAdd(Operation* op, + const TF::SideEffectAnalysis::Info& side_effect_analysis) { + if (closed_) return false; // If the op is not marked for outside compilation it doesn't belong in a // cluster. - if (!op->getAttrOfType(kXlaOutsideCompilationAttr)) + if (!op->getAttrOfType(kXlaOutsideCompilationAttr)) { + auto successors = side_effect_analysis.DirectControlSuccessors(op); + // If a non-outside-compiled op with side effect successors is encountered, + // close this cluster to additions so that no cluster cyclic dependencies + // can be created. + if (!successors.empty()) { + closed_ = true; + } return false; - - // Checks to see if the op's operands are related to this - // clusters users. If they are related, then there is an op between this - // op and the cluster. Since ops are added before their uses, there - // is no way for the op in-between to ever be added to this cluster - // therefore there is no way this op can ever be added to the cluster. - for (const Value& value : op->getOperands()) { - Operation* op_operand = value.getDefiningOp(); - if (op_operand && users_.find(op_operand) != users_.end()) return false; } - return true; + + if (host_cluster_ops_.empty()) return true; + + // Checks whether there is a data dependency between ops in + // `host_cluster_ops_` and `op`. + const bool contains_data_dependency = llvm::any_of( + op->getUsers(), + [&](Operation* user) { return host_cluster_ops_.contains(user); }); + + const bool inside_same_block = + llvm::all_of(host_cluster_ops_, [&](Operation* op_in_cluster) { + return op_in_cluster->getBlock() == op->getBlock(); + }); + + return inside_same_block && contains_data_dependency; } - // users_ stores the direct and indirect users of the outside compiled ops in - // this cluster. It does NOT store the outside compiled ops that are a part - // of this cluster that will be collectively extracted and run on the cpu. - // users_ is consulted when attempting to add a new outside compiled to the - // cluster. If the new op's operand(s) are already in users_, it means that - // the operand(s) were not added to the cluster so it is not safe to add the - // new op to the cluster either. - llvm::SmallPtrSet users_; + // `host_cluster_ops_` stores the set of ops that will be grouped and computed + // on the host as a single XlaHostCompute op. An outside compiled op can be + // grouped into a cluster if it has a data dependency on another op already in + // the cluster. + llvm::SmallPtrSet host_cluster_ops_; std::string cluster_name_; + bool closed_ = false; // Cluster is closed to further additions.
}; -void TPUOutsideCompilationCluster::runOnFunction() { +void TPUOutsideCompilationCluster::runOnFunction( + FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis) { llvm::SmallVector clusters; int cluster_counter = 0; - getFunction().walk([&](tf_device::ClusterOp tpu_cluster) { - for (Operation& op : tpu_cluster.GetBody()) { + func.walk([&](tf_device::ClusterOp tpu_cluster) { + llvm::SmallVector tpu_cluster_ops; + tpu_cluster_ops.reserve(tpu_cluster.getBody()->getOperations().size()); + + tpu_cluster.walk([&](Operation* op) { tpu_cluster_ops.emplace_back(op); }); + + // In order to cluster ops feeding results to the same operation, traverse + // the ops in reverse order. + for (Operation* op : llvm::reverse(tpu_cluster_ops)) { // Try to add the op to existing clusters. bool added = false; for (auto& cluster : clusters) - if ((added = cluster.AddOp(&op))) break; + if ((added = cluster.AddOp(op, side_effect_analysis))) break; // If the op cannot be added to existing clusters, create a new cluster. if (!added) { OutsideCompiledCluster new_cluster(cluster_counter++); - new_cluster.AddOp(&op); + new_cluster.AddOp(op, side_effect_analysis); clusters.push_back(new_cluster); } } @@ -118,7 +136,7 @@ void TPUOutsideCompilationCluster::runOnFunction() { } // anonymous namespace -std::unique_ptr> +std::unique_ptr> CreateTPUOutsideCompilationClusterPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_parallel_execute_sink_resource_write.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_parallel_execute_sink_resource_write.cc new file mode 100644 index 00000000000..45773a128fd --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_parallel_execute_sink_resource_write.cc @@ -0,0 +1,166 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TFTPU { + +namespace { + +// A pass that moves `tf.AssignVariableOp` into a `tf_device.parallel_execute` +// region if the `tf.AssignVariableOp` is the only consumer of a +// `tf_device.parallel_execute` result. This will allow +// TPUMergeVariablesWithExecute to merge resource writes without special +// handling for `tf_device.parallel_execute`. +struct TPUParallelExecuteSinkResourceWrite + : public PassWrapper { + void runOnFunction() override; +}; + +// Finds an AssignVariableOp that can be moved into the parallel_execute region. 
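The clustering logic above admits an outside compiled op into an existing cluster only when it sits in the same block as the cluster and directly feeds an op already in it. Read in isolation, the criterion amounts to something like the following sketch (CanJoinCluster is an illustrative name, not part of the patch):

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "mlir/IR/Operation.h"  // from @llvm-project

// Returns true if `op` may join `cluster_ops`: it must live in the same block
// as every op already in the cluster and have at least one user in it.
bool CanJoinCluster(mlir::Operation* op,
                    const llvm::SmallPtrSetImpl<mlir::Operation*>& cluster_ops) {
  if (cluster_ops.empty()) return true;
  const bool feeds_cluster =
      llvm::any_of(op->getUsers(), [&](mlir::Operation* user) {
        return cluster_ops.count(user) > 0;
      });
  const bool same_block =
      llvm::all_of(cluster_ops, [&](mlir::Operation* cluster_op) {
        return cluster_op->getBlock() == op->getBlock();
      });
  return feeds_cluster && same_block;
}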
+// These AssignVariableOps must be the only consumer of the respective +// parallel_execute result, and the resource handle producer must be from an op +// before or above the parallel_execute. +TF::AssignVariableOp GetSingleUseResourceWrite( + tf_device::ParallelExecuteOp parallel_execute, Value result) { + if (!result.hasOneUse()) return nullptr; + + OpOperand& use = *result.getUses().begin(); + auto assign_var = dyn_cast(use.getOwner()); + if (!assign_var) return nullptr; + + if (use.get() != assign_var.value()) return nullptr; + + auto* resource_handle_op = assign_var.resource().getDefiningOp(); + if (resource_handle_op == parallel_execute) return nullptr; + + if (resource_handle_op && + resource_handle_op->getBlock() == + parallel_execute.getOperation()->getBlock() && + parallel_execute.getOperation()->isBeforeInBlock(resource_handle_op)) + return nullptr; + + return assign_var; +} + +// Finds AssignVariableOps that can be moved into a parallel_execute region and +// moves them. Leftover parallel_execute results that were used by such +// AssignVariableOps are also pruned. +void SinkResourceWritesIntoParallelExecute( + tf_device::ParallelExecuteOp parallel_execute) { + bool rewrite = false; + const int num_regions = parallel_execute.getNumRegions(); + llvm::SmallVector results_to_remap; + + // Go through each region and find AssignVariableOps that can be moved into + // the parallel_execute region. Result indices by region index are collected, + // so they can be removed afterwards. + llvm::SmallVector, 4> results_to_remove_by_region; + results_to_remove_by_region.resize(num_regions); + for (int i = 0; i < num_regions; ++i) { + Block& block = parallel_execute.GetRegionBlockWithIndex(i); + auto results = parallel_execute.GetRegionOutputs(i); + auto& results_to_remove = results_to_remove_by_region[i]; + results_to_remove.reserve(results.size()); + Operation* terminator = block.getTerminator(); + for (auto result : llvm::enumerate(results)) { + TF::AssignVariableOp assign_var = + GetSingleUseResourceWrite(parallel_execute, result.value()); + if (!assign_var) { + results_to_remap.push_back(result.value()); + continue; + } + + // Move the AssignVariableOp and update the value to be written to the + // resource variable so that it is the non-forwarded value from within the + // parallel_execute region. + assign_var.getOperation()->moveBefore(terminator); + assign_var.valueMutable().assign(terminator->getOperand(result.index())); + results_to_remove.push_back(result.index()); + } + + rewrite |= !results_to_remove.empty(); + } + + if (!rewrite) return; + + // Remove leftover unused results (terminator operands) from moving + // AssignVariableOps into the parallel_execute region. + for (auto results_to_remove : llvm::enumerate(results_to_remove_by_region)) { + Block& block = + parallel_execute.GetRegionBlockWithIndex(results_to_remove.index()); + Operation* terminator = block.getTerminator(); + for (int index_to_remove : llvm::reverse(results_to_remove.value())) + terminator->eraseOperand(index_to_remove); + } + + // Replace old parallel_execute with new parallel_execute by moving the + // regions to a new parallel_execute and remapping the results.
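The defining-op check in GetSingleUseResourceWrite above exists because sinking a write whose resource handle is produced after the parallel_execute in the same block would move the AssignVariableOp to a point where its resource operand is not yet defined. A sketch of that ordering check in isolation, mirroring the pass's same-block special case (IsDefinedBefore is an illustrative name):

#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"      // from @llvm-project

// Returns true if `value` is available at `point`: it is a block argument,
// defined in another block (treated as available, as in the pass above), or
// defined by an op that comes before `point` in the same block.
bool IsDefinedBefore(mlir::Value value, mlir::Operation* point) {
  mlir::Operation* def = value.getDefiningOp();
  if (!def) return true;  // Block arguments dominate the whole block.
  if (def->getBlock() != point->getBlock()) return true;
  return def->isBeforeInBlock(point);
}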
+ llvm::SmallVector new_result_types; + new_result_types.reserve(results_to_remap.size()); + for (Value old_result : results_to_remap) + new_result_types.push_back(old_result.getType()); + + OpBuilder builder(parallel_execute); + auto new_parallel_execute = builder.create( + parallel_execute.getLoc(), num_regions, new_result_types); + + for (auto region : llvm::zip(new_parallel_execute.getRegions(), + parallel_execute.getRegions())) + std::get<0>(region)->takeBody(*std::get<1>(region)); + + for (auto result : + llvm::zip(results_to_remap, new_parallel_execute.getResults())) + std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); + + parallel_execute.erase(); +} + +void TPUParallelExecuteSinkResourceWrite::runOnFunction() { + llvm::SmallVector parallel_executes; + getFunction().walk([&](tf_device::ParallelExecuteOp parallel_execute) { + parallel_executes.push_back(parallel_execute); + }); + + for (tf_device::ParallelExecuteOp parallel_execute : parallel_executes) + SinkResourceWritesIntoParallelExecute(parallel_execute); +} + +} // anonymous namespace + +std::unique_ptr> +CreateTPUParallelExecuteSinkResourceWritePass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-parallel-execute-sink-resource-write", + "Moves tf.AssignVariableOp consumers of tf_device.parallel_execute into " + "tf_device.parallel_execute regions"); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc new file mode 100644 index 00000000000..cccd528da1d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc @@ -0,0 +1,140 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TFTPU { + +// A pass that finds TPU clusters with write only resource access and adds an +// associated resource read, so the resource can later be fused into TPUExecute. +namespace { +struct TPUResourceReadForWrite + : public PassWrapper> { + void runOnOperation() override; +}; + +// Helper struct holding a resource value and its associated type. 
+struct ResourceValueAndSubtype { + Value resource; + Type subtype; +}; + +// Finds resource handle and type for result if result writes to a resource. +ResourceValueAndSubtype GetResourceWriteResult( + tf_device::ClusterFuncOp cluster_func, Value result) { + ResourceValueAndSubtype resource; + if (!result.hasOneUse()) return resource; + Operation* result_user = *result.getUsers().begin(); + auto assign_var = dyn_cast(result_user); + if (!assign_var) return resource; + + auto handle = assign_var.resource(); + // Skip result if cluster writes to the same variable via multiple results. + for (Operation* handle_user : handle.getUsers()) { + if (handle_user == assign_var) continue; + auto assign_var_user = dyn_cast(handle_user); + if (!assign_var_user) continue; + if (assign_var_user.value().getDefiningOp() == cluster_func) + return resource; + } + + resource.resource = assign_var.resource(); + resource.subtype = assign_var.value().getType(); + return resource; +} + +// Checks if resource is read by TPU cluster. +bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func, + Value resource) { + for (Operation* resource_user : resource.getUsers()) + if (auto read = dyn_cast(resource_user)) + for (Operation* read_user : read.value().getUsers()) + if (read_user == cluster_func) return true; + + return false; +} + +void TPUResourceReadForWrite::runOnOperation() { + SmallVector cluster_funcs; + getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) { + cluster_funcs.push_back(cluster_func); + }); + + OpBuilder builder(&getContext()); + // Add resource reads for resource writes from TPU cluster where for such + // resources the TPU cluster does not read from. + for (tf_device::ClusterFuncOp cluster_func : cluster_funcs) { + builder.setInsertionPoint(cluster_func); + + SmallVector read_operands; + for (Value result : cluster_func.getResults()) { + // TODO(lyandy): Update pass to use resource alias analysis. + auto resource_and_type = GetResourceWriteResult(cluster_func, result); + if (!resource_and_type.resource) continue; + if (ClusterFuncHasResourceRead(cluster_func, resource_and_type.resource)) + continue; + auto new_read = builder.create( + resource_and_type.resource.getLoc(), resource_and_type.subtype, + resource_and_type.resource); + read_operands.push_back(new_read.value()); + } + + if (read_operands.empty()) continue; + + // Update caller and function types with new read operands. 
+ auto operands = llvm::to_vector<4>(cluster_func.getOperands()); + operands.append(read_operands.begin(), read_operands.end()); + + auto new_cluster_func = builder.create( + cluster_func.getLoc(), cluster_func.getResultTypes(), operands, + cluster_func.getAttrs()); + cluster_func.replaceAllUsesWith(new_cluster_func); + FuncOp func = cluster_func.getFunc(); + Block& block = func.front(); + for (Value read_operand : read_operands) + block.addArgument(read_operand.getType()); + + func.setType(FunctionType::get(block.getArgumentTypes(), + func.getCallableResults(), &getContext())); + cluster_func.erase(); + } +} + +} // namespace + +std::unique_ptr> CreateTPUResourceReadForWritePass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-resource-read-for-write", + "Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes " + "with no reads"); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index ca77feafc05..86aeec81150 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -25,7 +25,6 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project @@ -42,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" #include "tensorflow/compiler/xla/xla.pb.h" @@ -154,11 +154,8 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func, symbol_table.insert(clone); } - // Serialize module and return. 
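The EncapsulateFuncAndSerialize hunk around this point swaps a hand-rolled print-to-string for tensorflow::SerializeMlirModule. The underlying operation is just printing the module into a string buffer; a minimal sketch of what such a helper is expected to do (ModuleToString is an illustrative name, not the library API):

#include <string>

#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Module.h"  // from @llvm-project

// Prints `module` in its textual form and returns it as a string.
std::string ModuleToString(mlir::ModuleOp module) {
  std::string serialized;
  llvm::raw_string_ostream os(serialized);
  module.print(os);
  return os.str();  // Flushes the stream and returns the underlying string.
}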
- { - llvm::raw_string_ostream os(*serialized_func_module); - module_for_func.get().print(os); - } + *serialized_func_module = + tensorflow::SerializeMlirModule(module_for_func.get()); return success(); } @@ -409,12 +406,15 @@ Operation* BuildCompileOp( std::string txt_module; if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr; - auto result_type = + auto compilation_status_type = RankedTensorType::get({}, builder->getType()); + auto program_type = + RankedTensorType::get({2}, builder->getType()); auto compile_op = builder->create( - cluster_func.getLoc(), /*compilation_status=*/result_type, /*program=*/ - llvm::SmallVector(num_cores_per_replica, result_type), + cluster_func.getLoc(), + /*compilation_status=*/compilation_status_type, /*program=*/ + llvm::SmallVector(num_cores_per_replica, program_type), compile_op_operands, txt_module, txt_metadata); return WrapOpInLaunch(builder, compile_op.getLoc(), compile_op, @@ -598,9 +598,9 @@ void BuildTPUCompileSucceededAssertOp(Operation* compile_op, // func @main(%arg0: tensor) { // %0 = "tf.Shape"(%arg0) : (tensor) -> tensor // %1:2 = "tf._TPUCompileMlir"(%0) {device = "/CPU:0"} : -// (tensor) -> (tensor, tensor) +// (tensor) -> (tensor, tensor<2x!tf.string>) // %2 = "tf.TPUExecute"(%arg0, %1#0) {device = "/TPU:0"} : -// (tensor, tensor) -> tensor +// (tensor, tensor<2x!tf.string>) -> tensor // return // } // @@ -624,9 +624,9 @@ void BuildTPUCompileSucceededAssertOp(Operation* compile_op, // {n = 2 : i32, devices = ["/TPU:0", "/TPU:1"]} { // %1 = "tf.Shape"(%ri) : (tensor) -> tensor // %2:2 = "tf._TPUCompileMlir"(%1) {device = "/CPU:0"} : -// (tensor) -> (tensor, tensor) +// (tensor) -> (tensor, tensor<2x!tf.string>) // %3 = "tf.TPUExecute"(%ri, %2#0) : -// (tensor, tensor) -> tensor +// (tensor, tensor<2x!tf.string>) -> tensor // tf_device.return %3 : tensor // } // return @@ -644,7 +644,7 @@ LogicalResult Rewrite( int num_replicas = 1; tf_device::ReplicateOp replicate = cluster_func.getParentOfType(); - if (replicate) num_replicas = replicate.n().getLimitedValue(); + if (replicate) num_replicas = replicate.n(); auto num_cores_per_replica_attr = cluster_func.getAttrOfType( tensorflow::kNumCoresPerReplicaAttr); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc index 204a674e632..ecfd6b33503 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc @@ -54,6 +54,11 @@ namespace { constexpr char kDeviceAttr[] = "device"; typedef std::pair Conv2DWithBlockSize; +struct BlockArgumentInfo { + unsigned arg_num; + unsigned num_users; +}; + // A pass that applies automatic space to depth transform for the first or // frontier convolutions consume host inputs on TPU. // This is done by adding space to depth transform op after host input and @@ -108,7 +113,49 @@ struct TPUSpaceToDepthPass void runOnOperation() override; }; -// Handle padding before convolution for space to depth transform. +// Updates func argument type to have the updated input shape. 
+void UpdateFuncType(FuncOp func) { + auto arg_types = llvm::to_vector<8>(func.front().getArgumentTypes()); + auto result_types = + llvm::to_vector<4>(func.front().getTerminator()->getOperandTypes()); + func.setType(FunctionType::get(arg_types, result_types, func.getContext())); +} + +void HandleFuncOp(Operation* op) { + auto func = llvm::cast(op); + UpdateFuncType(func); +} + +// Handles cast op between the first convolution and the block argument. +LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef new_shape) { + auto cast_input = cast_op.x(); + // Update input type. + auto transform_result_type = + RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input)); + cast_input.setType(transform_result_type); + auto block_arg = cast_input.dyn_cast(); + auto cast_op_input = dyn_cast_or_null(cast_input.getDefiningOp()); + while (block_arg || cast_op_input) { + if (block_arg) { + // Change on device function type/shape. + HandleFuncOp(block_arg.getOwner()->getParentOp()); + block_arg = nullptr; + cast_op_input = nullptr; + } else { + auto cast_input = cast_op_input.x(); + // Update input type. + auto transform_result_type = + RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input)); + cast_input.setType(transform_result_type); + // Update block arg and cast_op_input. + block_arg = cast_input.dyn_cast(); + cast_op_input = dyn_cast_or_null(cast_input.getDefiningOp()); + } + } + return success(); +} + +// Handles padding before convolution for space to depth transform. LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) { auto ranked_type = op.input().getType().dyn_cast(); if (!ranked_type) return failure(); @@ -134,6 +181,10 @@ LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) { pad_input_shape[0], pad_input_shape[1] / block_size, pad_input_shape[2] / block_size, pad_input_shape[3] * block_size * block_size}; + // Input of the pad op could be a cast op. + if (auto cast_op = dyn_cast_or_null(input.getDefiningOp())) + if (failed(HandleCast(cast_op, transform_shape))) return failure(); + auto transform_result_type = RankedTensorType::get(transform_shape, getElementTypeOrSelf(input)); input.setType(transform_result_type); @@ -141,7 +192,7 @@ LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) { return success(); } -// Handle stride for the first convolution for the transform. +// Handles stride for the first convolution for the transform. void HandleConv2DStride(TF::Conv2DOp conv2d) { MLIRContext* context = conv2d.getContext(); SmallVector values = {1, 1, 1, 1}; @@ -153,7 +204,7 @@ void HandleConv2DStride(TF::Conv2DOp conv2d) { conv2d.setAttr("strides", strides); } -// Transform input shape for the first convolution. +// Transforms input shape for the first convolution. void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) { auto input = conv2d.input(); auto input_shape = input.getType().cast().getShape(); @@ -165,7 +216,7 @@ void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) { input.setType(transform_result_type); } -// Add padding for convolution filter for space to depth transform. +// Adds padding for convolution filter for space to depth transform. TF::PadOp GetPadOpForConv2DFilter(ArrayRef filter_shape, Value filter, OpBuilder* builder, int32_t pad_h, int32_t pad_w) { @@ -185,7 +236,7 @@ TF::PadOp GetPadOpForConv2DFilter(ArrayRef filter_shape, Value filter, paddings_value); } -// Create reshape op for space to depth transform. 
+// Creates reshape op for space to depth transform. TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef new_shape, Value input, OpBuilder* builder) { auto reshape_result_type = @@ -199,7 +250,7 @@ TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef new_shape, input, reshape_value); } -// Create transpose op for shape to depth transform. +// Creates transpose op for shape to depth transform. TF::TransposeOp GetTransposeOpForConv2DFilter(OpBuilder* builder, Value input) { SmallVector permutation = {0, 2, 1, 3, 4, 5}; auto permute_type = RankedTensorType::get({6}, builder->getIntegerType(32)); @@ -259,7 +310,7 @@ void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) { conv2d.setOperand(1, final_reshape_op); } -// Create slice op for filter in back prop pass. +// Creates slice op for filter in back prop pass. TF::SliceOp GetSliceOpForConv2DBackPropFilter( ArrayRef old_filter_shape, Value input, OpBuilder* builder) { SmallVector slice_size(old_filter_shape.begin(), @@ -281,7 +332,7 @@ TF::SliceOp GetSliceOpForConv2DBackPropFilter( start_position, slice_size_op); } -// Transform Conv2DBackPropFilter for space to depth. +// Transforms Conv2DBackPropFilter for space to depth. void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop, ArrayRef old_filter_shape, ArrayRef new_filter_shape, @@ -354,22 +405,6 @@ void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop, backprop.replaceAllUsesWith(slice_op.getResult()); } -// Update func arugument type to have the updated input shape. -void UpdateFuncType(FuncOp func) { - llvm::SmallVector arg_types; - arg_types.reserve(func.getNumArguments()); - for (auto arg : func.getArguments()) arg_types.emplace_back(arg.getType()); - auto terminator = func.front().getTerminator(); - SmallVector result_types(terminator->operand_type_begin(), - terminator->operand_type_end()); - func.setType(FunctionType::get(arg_types, result_types, func.getContext())); -} - -void HandleFuncOp(Operation* op) { - auto func = llvm::cast(op); - UpdateFuncType(func); -} - // Checks if the input producer op is supported in this transform. Right now, we // only check if it is a host tf.IteratorGetNext. bool IsSupportedHostInputOp(Operation* op) { @@ -397,9 +432,8 @@ TF::SpaceToDepthOp BuildSpaceToDepth(tf_device::ClusterFuncOp cluster_func, input_shape[3] * block_size * block_size}; auto transform_result_type = RankedTensorType::get(transform_shape, getElementTypeOrSelf(input)); - return builder.create(cluster_func.getLoc(), - transform_result_type, input, - APInt(64, block_size)); + return builder.create( + cluster_func.getLoc(), transform_result_type, input, block_size); } // Performs transformation for a non-replicated input. @@ -417,12 +451,13 @@ TF::SpaceToDepthOp HandleHostInput(Value input, int64_t index, // supported case (thus transform happened). bool HandleHostReplicatedInputs(int64_t index, tf_device::ClusterFuncOp cluster_func, - int64_t replicate_arg_index, + BlockArgument block_arg, tf_device::ReplicateOp replicate, int32_t block_size) { + int64_t replicate_arg_index = block_arg.getArgNumber(); // We need to know the devices to copy to. if (!replicate.devices()) return false; - int64_t num_replicas = replicate.n().getZExtValue(); + int64_t num_replicas = replicate.n(); // Gets inputs at replicate_arg_index for each replica. 
auto inputs = replicate.getOperands() .drop_front(replicate_arg_index * num_replicas) @@ -439,6 +474,7 @@ bool HandleHostReplicatedInputs(int64_t index, BuildSpaceToDepth(cluster_func, entry.value(), block_size, input_shape); replicate.setOperand(num_replicas * replicate_arg_index + entry.index(), space_to_depth); + block_arg.setType(space_to_depth.getType()); } return true; } @@ -457,9 +493,8 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size, // For a block argument, consider transforms only when it is a replicated // input (defining ops will be outside the replicate node). if (maybe_replicate == block_arg.getParentRegion()->getParentOp()) { - HandleHostReplicatedInputs(input.index(), cluster_func, - block_arg.getArgNumber(), maybe_replicate, - block_size); + HandleHostReplicatedInputs(input.index(), cluster_func, block_arg, + maybe_replicate, block_size); } } else { // For an op output, consider transforms only when 1) there is no @@ -482,7 +517,7 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size, } } -// Check if input shape of convolution is good for space to depth transform. +// Checks if input shape of convolution is good for space to depth transform. bool Conv2DInputShapeCanTransform(Value input) { auto ranked_type = input.getType().dyn_cast(); if (!ranked_type) return false; @@ -495,35 +530,59 @@ bool Conv2DInputShapeCanTransform(Value input) { return true; } -// Checks if a convoluton can apply SpaceToDepth transform. -// Only the first convolution in the graph whose batch size smaller than 8 -// and its input feature size smaller than 8 can be transformed. -Optional> GetConv2DInputArgNum(TF::Conv2DOp conv2d) { - if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) { - return None; - } - auto conv2d_input = conv2d.input(); - if (auto block_arg = conv2d_input.dyn_cast()) { - if (!Conv2DInputShapeCanTransform(conv2d_input)) return None; - int num_users = +// Get block argument id and number of users for the input arg. +Optional GetBlockArgNum(Value arg) { + if (auto block_arg = arg.dyn_cast()) { + if (!Conv2DInputShapeCanTransform(arg)) return None; + unsigned num_users = std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end()); - return std::make_pair(block_arg.getArgNumber(), num_users); + BlockArgumentInfo block_arg_info = {block_arg.getArgNumber(), num_users}; + return block_arg_info; } + return None; +} - if (auto pad_op = llvm::dyn_cast(conv2d_input.getDefiningOp())) { - auto pad_input = pad_op.input(); - if (auto block_arg = pad_input.dyn_cast()) { - if (!Conv2DInputShapeCanTransform(pad_input)) return None; - int num_users = std::distance(block_arg.getUsers().begin(), - block_arg.getUsers().end()); - return std::make_pair(block_arg.getArgNumber(), num_users); +// Gets input block argument id and number of users for the input recursively. +// Current supported ops between convolution input and the block arguments are +// PadOp and CastOp. 
+Optional GetInputBlockArgNum(Value input) { + auto block_arg_num = GetBlockArgNum(input); + if (block_arg_num.hasValue()) return block_arg_num; + + Value next_input = input; + auto pad_op = dyn_cast_or_null(next_input.getDefiningOp()); + auto cast_op = dyn_cast_or_null(next_input.getDefiningOp()); + + while (pad_op || cast_op) { + if (pad_op) { + auto block_arg_num = GetBlockArgNum(pad_op.input()); + if (block_arg_num.hasValue()) return block_arg_num; + next_input = pad_op.input(); + } else { + auto block_arg_num = GetBlockArgNum(cast_op.x()); + if (block_arg_num.hasValue()) return block_arg_num; + next_input = cast_op.x(); } + pad_op = dyn_cast_or_null(next_input.getDefiningOp()); + cast_op = dyn_cast_or_null(next_input.getDefiningOp()); } return None; } -// Apply space to depth transform for the first convolution on TPU device. +// Checks if a convoluton can apply SpaceToDepth transform. +// Only the first convolution in the graph whose batch size smaller than 8 +// and its input feature size smaller than 8 can be transformed. +Optional GetConv2DInputArgNum(TF::Conv2DOp conv2d) { + if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) { + return None; + } + // Current supported ops between convolution input and the block arguments are + // PadOp and CastOp. + return GetInputBlockArgNum(conv2d.input()); +} + +// Applies space to depth transform for the first convolution on TPU device. void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) { // Check if input and filter type are RankedTensorType. auto input_tensor_type = @@ -563,8 +622,9 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) { SmallVector new_filter_shape(filter_shape.begin(), filter_shape.end()); - // Rewrite Conv2DBackPropFilter after the first convolution. - for (Operation* user : conv2d.getOperation()->getUsers()) { + // Rewrite Conv2DBackPropFilter that is the user of first convolution's input. + if (!conv2d_input.getDefiningOp()) return; + for (Operation* user : conv2d_input.getDefiningOp()->getUsers()) { if (auto backprop = dyn_cast(user)) { HandleConv2DBackPropFilter(backprop, old_filter_shape, new_filter_shape, block_size); @@ -572,7 +632,7 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) { } } -// Get block size that is equal to stride from spatial dimension +// Gets block size that is equal to stride from spatial dimension // from convolution. // Space to depth transform won't be triggered if block size <= 1. int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) { @@ -608,7 +668,6 @@ void TPUSpaceToDepthPass::runOnOperation() { if (!device_func) return; TF::Conv2DOp first_conv; - Optional> input_shape; // A map maps block argument id to the convolutions consumes them. llvm::SmallDenseMap> argnum_and_convolutions; @@ -617,13 +676,13 @@ void TPUSpaceToDepthPass::runOnOperation() { // Find out the qualified convolutions and its block argument ids. auto conv2d_result = device_func.walk([&](TF::Conv2DOp conv2d) { - Optional> arg_num_and_num_users = + Optional arg_num_and_num_users = GetConv2DInputArgNum(conv2d); if (arg_num_and_num_users.hasValue()) { // Get block size for the first convolution. 
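The transformed shapes used throughout this pass follow the standard NHWC space-to-depth rule: spatial dimensions shrink by the block size while the channel dimension grows by its square. As a small reference sketch (SpaceToDepthShape is an illustrative helper, not part of the patch):

#include <cassert>
#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Computes the NHWC result shape of a SpaceToDepth transform:
// [b, h, w, c] -> [b, h / bs, w / bs, c * bs * bs].
llvm::SmallVector<int64_t, 4> SpaceToDepthShape(llvm::ArrayRef<int64_t> shape,
                                                int64_t block_size) {
  assert(shape.size() == 4 && block_size > 1 && "expected NHWC and block > 1");
  return {shape[0], shape[1] / block_size, shape[2] / block_size,
          shape[3] * block_size * block_size};
}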
int64_t block_size = GetConv2DBlockSize(conv2d); - auto arg_num = arg_num_and_num_users.getValue().first; - auto num_users = arg_num_and_num_users.getValue().second; + auto arg_num = arg_num_and_num_users.getValue().arg_num; + auto num_users = arg_num_and_num_users.getValue().num_users; argnum_and_convolutions[arg_num].emplace_back(conv2d, block_size); argnum_num_users[arg_num] = num_users; return WalkResult::interrupt(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index 3262b83fc94..0e4ef76a54c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -174,7 +174,7 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( assert(metadata_str && "Missing compilation metadata"); tensorflow::tpu::TPUCompileMetadataProto metadata; metadata.ParseFromString(std::string(metadata_str.getValue())); - int64_t num_replicas = replicate.n().getLimitedValue(); + int64_t num_replicas = replicate.n(); // Find the formattable operands of `execute`, which must be mirrored // variables (arguments of `replicate`), and must be pass-throughs from while // operands. @@ -264,7 +264,7 @@ tf_device::ReplicateOp AddInputsToReplicateOp( tf_device::ReplicateOp replicate, ArrayRef new_inputs, const llvm::SmallDenseMap>& devices) { - int64_t num_replicas = replicate.n().getLimitedValue(); + int64_t num_replicas = replicate.n(); assert(new_inputs.size() == num_replicas); // As model parallelism is not yet supported, we assume that all ops are @@ -423,7 +423,7 @@ void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op, // Performs the transformation for a replicate op inside a while loop. void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, MLIRContext* context) { - int64_t num_replicas = replicate.n().getLimitedValue(); + int64_t num_replicas = replicate.n(); if (num_replicas == 1) return; tf_device::LaunchOp execute_launch; for (auto execute_launch_op : @@ -452,8 +452,8 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, !llvm::isa(compile_launch.GetBody().front())) return; - FuncOp body = while_op.body_func(); - FuncOp cond = while_op.cond_func(); + FuncOp body = while_op.body_function(); + FuncOp cond = while_op.cond_function(); // Analyze the formattable inputs. auto execute_arg_to_outer_args = @@ -537,9 +537,10 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, // Build a constant default key to specify that the unformatting should // transform the variables to the original format. 
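In the tpu_variable_runtime_reformatting.cc hunk that follows, the default state key becomes a 3-element DT_STRING tensor. Building such a constant from a tensorflow::Tensor and converting it to an MLIR attribute follows the pattern already used in that pass; a hedged sketch, where BuildDefaultStateKey and the default_sharding parameter are illustrative stand-ins:

#include <string>

#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"    // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/core/framework/tensor.h"

// Builds a 3-element DT_STRING tensor filled with the default sharding value
// and converts it to an ElementsAttr usable as a tf.Const value.
mlir::ElementsAttr BuildDefaultStateKey(const std::string& default_sharding,
                                        mlir::Builder* builder) {
  tensorflow::Tensor key(tensorflow::DT_STRING, {3});
  for (int i = 0; i < 3; ++i)
    key.vec<tensorflow::tstring>()(i) = default_sharding;
  return tensorflow::ConvertTensor(key, builder).ValueOrDie();
}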
builder.setInsertionPointAfter(while_op); - tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {2}); + tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {3}); default_key_tensor.vec()(0) = kDefaultShardingValue; default_key_tensor.vec()(1) = kDefaultShardingValue; + default_key_tensor.vec()(2) = kDefaultShardingValue; auto default_state_key = builder.create( while_op.getLoc(), tensorflow::ConvertTensor(default_key_tensor, &builder).ValueOrDie()); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc index 0a69987deb0..b65f07c39ac 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc @@ -43,6 +43,10 @@ namespace { class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass< BreakUpIslands, TF::SideEffectAnalysis> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: void runOnFunction(FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 571d5e3e715..0445dbb698a 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/utils/name_utils.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -80,46 +81,14 @@ constexpr char kInvalidExecutorGraphMsg[] = constexpr char kDeviceAttr[] = "tf.device"; constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; -bool IsLegalChar(char c, bool first_char) { - if (isalpha(c)) return true; - if (isdigit(c)) return true; - if (c == '.') return true; - if (c == '_') return true; - - // First character of a node name can only be a letter, digit, dot or - // underscore. - if (first_char) return false; - - if (c == '/') return true; - if (c == '-') return true; - - return false; -} - -// Convert characters in name that are considered illegal in TensorFlow Node -// name to '.'. -std::string LegalizeNodeName(llvm::StringRef name) { - assert(!name.empty() && "expected non-empty name"); - - std::string legalized_name; - bool first = true; - for (auto c : name) { - if (IsLegalChar(c, first)) { - legalized_name += c; - } else { - legalized_name += '.'; - } - first = false; - } - - return legalized_name; -} - // OpOrArgLocNameMapper that legalizes the returned name. 
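// Illustrative sketch (not part of this patch): node-name legalization now
// comes from tensorflow/compiler/mlir/utils/name_utils.h (included above) and
// mutates the string in place. Based on the deleted IsLegalChar/LegalizeNodeName
// pair, the rule is roughly the following; this standalone reimplementation is
// for illustration only.
#include <cctype>
#include <string>

void ExampleLegalizeNodeName(std::string& name) {
  for (size_t i = 0; i < name.size(); ++i) {
    const char c = name[i];
    bool legal =
        std::isalnum(static_cast<unsigned char>(c)) || c == '.' || c == '_';
    // '/' and '-' are legal anywhere except as the first character.
    if (i > 0 && (c == '/' || c == '-')) legal = true;
    if (!legal) name[i] = '.';  // Illegal characters become '.'.
  }
}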
class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper { private: std::string GetName(OpOrVal op_or_val) override { - return LegalizeNodeName(OpOrArgLocNameMapper::GetName(op_or_val)); + std::string name = OpOrArgLocNameMapper::GetName(op_or_val); + assert(!name.empty() && "expected non-empty name"); + mlir::LegalizeNodeName(name); + return name; } }; @@ -523,13 +492,14 @@ StatusOr> Exporter::Convert( if (index >= num_data_results) break; // TODO(jpienaar): If there is a result index specified, ensure only one // and that it matches the result index of the op. - std::string orig_name(output_names[index]); - auto tensor_id = ParseTensorName(orig_name); - auto name = LegalizeNodeName( - llvm::StringRef(tensor_id.node().data(), tensor_id.node().size())); + std::string name(output_names[index]); + auto tensor_id = ParseTensorName(name); + std::string tensor_id_node(tensor_id.node()); + assert(!tensor_id_node.empty() && "expected non-empty name"); + mlir::LegalizeNodeName(tensor_id_node); // Ensure name does not get reused. - (void)exporter.op_to_name_.GetUniqueName(name); + (void)exporter.op_to_name_.GetUniqueName(tensor_id_node); } } @@ -537,8 +507,9 @@ StatusOr> Exporter::Convert( TF_RET_CHECK(input_names.size() == block.getNumArguments()); for (const auto& it : llvm::enumerate(function.getArguments())) { // TODO(lyandy): Update when changing feed/fetch import. - std::string orig_name(input_names[it.index()]); - std::string name = LegalizeNodeName(orig_name); + std::string name(input_names[it.index()]); + assert(!name.empty() && "expected non-empty name"); + mlir::LegalizeNodeName(name); auto tensor_id = ParseTensorName(name); TF_RET_CHECK(tensor_id.index() == 0) << "input port designation not supported"; @@ -726,7 +697,7 @@ Status Exporter::Convert(mlir::ModuleOp module, mlir::Identifier::get("main", module.getContext()); absl::optional entry_func; FunctionDefLibrary flib; - auto tf_dialect = module.getContext()->getRegisteredDialect("tf"); + auto tf_dialect = module.getContext()->getLoadedDialect("tf"); for (auto function : module.getOps()) { if (function.isExternal()) return errors::FailedPrecondition("External functions not supported"); @@ -799,7 +770,7 @@ StatusOr> ConvertMlirToGraphdef( stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef( mlir::FuncOp func, const GraphExportConfig& configs, FunctionDef* function_def) { - Dialect* tf_dialect = func.getContext()->getRegisteredDialect("tf"); + Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf"); FunctionDefLibrary flib; TF_RETURN_IF_ERROR( Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib)); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc index 3ca06e5efa9..727831a6055 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc @@ -25,8 +25,8 @@ limitations under the License. 
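// Illustrative sketch (not part of this patch): with MLIR moving from global
// dialect registration to per-context loading, a context only knows the
// dialects that were explicitly loaded into it. That is why the exporter above
// now queries getLoadedDialect("tf") instead of getRegisteredDialect, and why
// the importer and translation registrations later in this change load the TF
// dialects up front. A minimal loading sequence, mirroring the
// LoadImporterDialects helper introduced below, looks like this (includes are
// assumed to be sufficient):
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"

void ExampleLoadTensorFlowDialects(mlir::MLIRContext& context) {
  mlir::DialectRegistry registry;
  mlir::RegisterAllTensorFlowDialects(registry);
  registry.loadAll(&context);
}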
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -34,7 +34,6 @@ limitations under the License. namespace tensorflow { namespace { -using stream_executor::port::StatusOr; // Sets type list attribute with the given `name` to the given `types`. If the // attribute already exists with a different value, returns an error. @@ -90,7 +89,7 @@ Status SetShapeAttribute(absl::string_view name, ContainerT shapes, // definitions and isn't a header file. #include "tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator.inc" -// Collect all the unregistered attributes for an TF dialect operation. +// Collects all the unregistered attributes for an TF dialect operation. // Attributes "name" and "device" are not included because they are not part // of an TF op attributes. Status GetUnregisteredAttrs( @@ -123,17 +122,10 @@ Status GetUnregisteredAttrs( return Status::OK(); } -} // namespace - -StatusOr> ConvertTFDialectOpToNodeDef( - mlir::Operation* inst, llvm::StringRef name, - bool ignore_unregistered_attrs) { - // Use auto generated function to populate derived attribute. - // - // Note: This only populates derived attributes for TensorFlow ops that are - // generated using the TableGen. Manually defined ops should have all the - // attributes present as native MLIR op attributes. - +// Collects all attribute names to ignore in an MLIR operation when exporting to +// a TensorFlow NodeDef. +StatusOr> GetAttributesToIgnore( + mlir::Operation* inst, bool ignore_unregistered_attrs) { // The elements are owned by the MLIRContext. absl::flat_hash_set attrs_to_ignore; if (inst->isRegistered()) { @@ -162,15 +154,25 @@ StatusOr> ConvertTFDialectOpToNodeDef( attrs_to_ignore.insert(attr_name.data()); } - TF_ASSIGN_OR_RETURN(auto node_def, - GetOperationNodeDef(attrs_to_ignore, inst, name)); + return attrs_to_ignore; +} + +// Populates all derived attributes of a MLIR operation in a proto +// map. +Status PopulateDerivedAttributes(mlir::Operation* inst, + bool ignore_unregistered_attrs, + AttrValueMap* attributes) { + // Use auto generated function to populate derived attribute. + // + // Note: This only populates derived attributes for TensorFlow ops that are + // generated using the TableGen. Manually defined ops should have all the + // attributes present as native MLIR op attributes. 
// If the operation is not registered, we won't be able to infer any attribute if (inst->isRegistered()) { - TF_RETURN_WITH_CONTEXT_IF_ERROR( - PopulateDerivedAttrs(inst, node_def->mutable_attr()), - "When populating derived attrs for ", - inst->getName().getStringRef().str()); + TF_RETURN_WITH_CONTEXT_IF_ERROR(PopulateDerivedAttrs(inst, attributes), + "When populating derived attrs for ", + inst->getName().getStringRef().str()); } // Here we only add the shapes for the leading values with ShapedType, @@ -185,10 +187,38 @@ StatusOr> ConvertTFDialectOpToNodeDef( mlir::TF::ResultShapeRange output_shapes = { mlir::TF::ResultShapeIterator(begin), mlir::TF::ResultShapeIterator(end)}; - TF_RETURN_IF_ERROR(SetShapeAttribute("_output_shapes", output_shapes, - node_def->mutable_attr())); + TF_RETURN_IF_ERROR( + SetShapeAttribute("_output_shapes", output_shapes, attributes)); } } + + return Status::OK(); +} + +} // namespace + +Status GetAttrValuesFromOperation(mlir::Operation* inst, llvm::StringRef name, + bool ignore_unregistered_attrs, + AttrValueMap* attributes) { + TF_ASSIGN_OR_RETURN(auto attrs_to_ignore, + GetAttributesToIgnore(inst, ignore_unregistered_attrs)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ConvertAttributes(inst->getAttrs(), attrs_to_ignore, attributes), + "while converting attributes for node: ", name.str()); + TF_RETURN_IF_ERROR( + PopulateDerivedAttributes(inst, ignore_unregistered_attrs, attributes)); + return Status::OK(); +} + +StatusOr> ConvertTFDialectOpToNodeDef( + mlir::Operation* inst, llvm::StringRef name, + bool ignore_unregistered_attrs) { + TF_ASSIGN_OR_RETURN(auto attrs_to_ignore, + GetAttributesToIgnore(inst, ignore_unregistered_attrs)); + TF_ASSIGN_OR_RETURN(auto node_def, + GetOperationNodeDef(attrs_to_ignore, inst, name)); + TF_RETURN_IF_ERROR(PopulateDerivedAttributes(inst, ignore_unregistered_attrs, + node_def->mutable_attr())); return node_def; } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h index a19ad1f2940..bd260171a86 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h @@ -18,12 +18,22 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { -// Converts an MLIR operation to TensorFlow NodeDef with given node name. This +// Extracts the attributes of a MLIR operation and populates the converted +// attributes in a proto map. +Status GetAttrValuesFromOperation(mlir::Operation* inst, llvm::StringRef name, + bool ignore_unregistered_attrs, + AttrValueMap* attributes); + +// Converts a MLIR operation to TensorFlow NodeDef with given node name. This // name should be unique to the graph it is being inserted to. If the // `ignore_unregistered_attrs` argument is set to true, the attributes which are // not in the op registry will be ignored. If the `ignore_unregistered_attrs` @@ -31,9 +41,9 @@ namespace tensorflow { // ShapedType for the leading values with ShapedType in the results of the // nodes. 
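// Illustrative usage sketch (not part of this patch) of the two entry points
// declared in this header, assuming it sits in namespace tensorflow and the
// usual TF status macros are available; the operation pointer and node name
// are placeholders. ConvertTFDialectOpToNodeDef yields a full NodeDef, while
// GetAttrValuesFromOperation fills an AttrValueMap the caller already owns.
Status ExampleExport(mlir::Operation* inst) {
  TF_ASSIGN_OR_RETURN(auto node_def,
                      ConvertTFDialectOpToNodeDef(
                          inst, "example_node",
                          /*ignore_unregistered_attrs=*/false));
  AttrValueMap attrs;
  TF_RETURN_IF_ERROR(GetAttrValuesFromOperation(
      inst, "example_node", /*ignore_unregistered_attrs=*/false, &attrs));
  // node_def and attrs would now be consumed by the caller.
  return Status::OK();
}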
Set it to true if the returned NodeDef will be executed by the linked // TF Eager runtime. -stream_executor::port::StatusOr> -ConvertTFDialectOpToNodeDef(mlir::Operation* inst, llvm::StringRef name, - bool ignore_unregistered_attrs); +StatusOr> ConvertTFDialectOpToNodeDef( + mlir::Operation* inst, llvm::StringRef name, + bool ignore_unregistered_attrs); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 27385e81262..153c537589c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -64,6 +64,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader_util.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -141,6 +142,13 @@ bool IsResourceOutputShapesAttribute(const AttrValue& attr_value, return false; } +void LoadImporterDialects(mlir::MLIRContext& context) { + // Load dialects involved in the conversion + mlir::DialectRegistry registry; + mlir::RegisterAllTensorFlowDialects(registry); + registry.loadAll(&context); +} + // This class is used to generate new MLIR function name strings that are both // unique in the TF function library `flib_` and unique among the name strings // generated by the class object during its lifetime. @@ -177,7 +185,8 @@ Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def, restrict_functionalization_to_tpu_nodes ? [](const Node* n) { return n->attrs().Find(kTpuReplicateAttr); } : NodeFilter{}; - return FunctionalizeControlFlow(graph, flib_def, node_filter); + return FunctionalizeControlFlow(graph, flib_def, node_filter, + /*include_functions=*/true); } // Stateful helper class to import a TensorFlow model into an MLIR Module. @@ -1934,22 +1943,18 @@ Status ImporterBase::ConvertNode(const Node& node) { } } - // Map If and StatelessIf op in TensorFlow to the common If op in MLIR and add - // the differentiating attribute. - if (node.IsIfNode()) { - result.name = mlir::OperationName(get_full_op_name("If"), context_); - mlir::BoolAttr val = builder_.getBoolAttr(node_type_name == "StatelessIf"); + auto composite_control_flow_op = [&](const std::string& name) { + result.name = mlir::OperationName(get_full_op_name(name), context_); + bool stateless = absl::StartsWith(node_type_name, "Stateless"); + mlir::BoolAttr val = builder_.getBoolAttr(stateless); result.attributes.push_back(builder_.getNamedAttr("is_stateless", val)); - } + }; - // Map While and StatelessWhile op in TensorFlow to the common While op in - // MLIR and add the differentiating attribute. - if (node.IsWhileNode()) { - result.name = mlir::OperationName(get_full_op_name("While"), context_); - mlir::BoolAttr val = - builder_.getBoolAttr(node_type_name == "StatelessWhile"); - result.attributes.push_back(builder_.getNamedAttr("is_stateless", val)); - } + // Map Case/If/While and StatelessCase/If/While op in TensorFlow to the common + // Case/If/While op in MLIR and add the differentiating attribute. 
+ if (node.IsCaseNode()) composite_control_flow_op("Case"); + if (node.IsIfNode()) composite_control_flow_op("If"); + if (node.IsWhileNode()) composite_control_flow_op("While"); // Register the mapping between the TF node and the newly created operation. node_values_[node.id()] = @@ -2139,6 +2144,7 @@ StatusOr GraphDefImporter::Convert( mlir::MLIRContext* context, const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, llvm::StringRef func_name) { + LoadImporterDialects(*context); mlir::OwningModuleRef module = mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); std::unordered_map tf_name_to_mlir_name; @@ -2873,7 +2879,7 @@ void AdjustBoundInputArgTypes(mlir::ModuleOp module) { mlir::OpBuilder builder(func.getBody()); llvm::SmallVector new_input_types; for (int i = 0, e = func.getNumArguments(); i < e; i++) { - auto arg = func.front().getArgument(i); + auto arg = func.getArgument(i); auto global_tensor = mlir::tf_saved_model::LookupBoundInputOfType< mlir::tf_saved_model::GlobalTensorOp>(func, i, symbol_table); if (global_tensor) { @@ -3195,6 +3201,7 @@ Status CreateSavedModelIR( StatusOr SavedModelObjectGraphImporter::Convert( SavedModelV2Bundle* saved_model, absl::Span exported_names, mlir::MLIRContext* context, bool add_default_attributes) { + LoadImporterDialects(*context); GraphDebugInfo dummy_debug_info; const GraphDebugInfo& debug_info = saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info; @@ -3274,6 +3281,7 @@ class SavedModelSignatureDefImporter { static StatusOr Convert( const SavedModelBundle& bundle, absl::Span exported_names, mlir::MLIRContext* context, bool upgrade_legacy) { + LoadImporterDialects(*context); SavedModelSignatureDefImporter importer(bundle, exported_names, context); TF_RETURN_IF_ERROR(importer.InitializeGraph(upgrade_legacy)); return importer.ConvertSignatures(); @@ -3646,6 +3654,8 @@ stream_executor::port::StatusOr ConvertFunctionToMlir( tensorflow::GraphDebugInfo dummy_debug_info; tensorflow::GraphImportConfig specs; specs.graph_as_function = true; + for (const auto* control_ret_node : fbody->control_ret_nodes) + specs.control_outputs.push_back(control_ret_node->name()); return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info, flib_def, specs, name); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index 1c7988d3a40..58377661a23 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -219,22 +219,18 @@ StatusOr GraphdefToSplattedMlirTranslateFunction( if (auto attr = inst.getAttrOfType(attr_id)) { mlir::Attribute rand_val; mlir::Type element_type = attr.getType().getElementType(); + if (element_type.isa()) { + rand_val = mlir::IntegerAttr::get(element_type, std::rand()); + } else if (element_type.isF16() || element_type.isF32() || + element_type.isF64()) { + rand_val = mlir::FloatAttr::get(element_type, + std::rand() * 1.0 / RAND_MAX); - switch (element_type.getKind()) { - case mlir::StandardTypes::Integer: - rand_val = mlir::IntegerAttr::get(element_type, std::rand()); - break; - case mlir::StandardTypes::F16: - case mlir::StandardTypes::F32: - case mlir::StandardTypes::F64: - rand_val = mlir::FloatAttr::get(element_type, - std::rand() * 1.0 / RAND_MAX); - break; - default: - inst.emitWarning() - << "Skipping splat conversion for " - << "an 
unsupported attribute type " << element_type; - continue; + } else { + inst.emitWarning() + << "Skipping splat conversion for " + << "an unsupported attribute type " << element_type; + continue; } auto new_attr = mlir::DenseElementsAttr::get(attr.getType(), rand_val); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index b646e14b71d..f63cb091a09 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/Support/MemoryBuffer.h" #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" @@ -86,6 +87,9 @@ static LogicalResult MlirToGraphdefTranslateFunction( } static TranslateFromMLIRRegistration mlir_to_graphdef_translate( - "mlir-to-graphdef", MlirToGraphdefTranslateFunction); + "mlir-to-graphdef", MlirToGraphdefTranslateFunction, + [](DialectRegistry& registry) { + mlir::RegisterAllTensorFlowDialects(registry); + }); } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc index 5236bdeffbf..22e6559a0f2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" namespace mlir { @@ -67,6 +68,7 @@ static LogicalResult MlirToTfNodeDef(ModuleOp module, // Test only translation to convert a simple MLIR module with a single TF // dialect op to NodeDef. static TranslateFromMLIRRegistration translate_from_mlir_registration( - "test-only-mlir-to-tf-nodedef", MlirToTfNodeDef); + "test-only-mlir-to-tf-nodedef", MlirToTfNodeDef, + mlir::RegisterAllTensorFlowDialects); } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h similarity index 66% rename from tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h rename to tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h index 599a8df63d7..bd81cae5730 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_ATTRIBUTE_UTILS_H_ -#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_ATTRIBUTE_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_ #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project @@ -36,7 +36,18 @@ inline void CopyUnderscoredAttributes(Operation *from, Operation *to) { }); } +// Copies attributes that are either `device` or whose name begins with an _ +// from `from` to `to`. +// TODO(b/158769932): This should be a general feature instead post some policy +// discussion. +inline void CopyDeviceAndUnderscoredAttributes(Operation *from, Operation *to) { + auto device = mlir::Identifier::get("device", from->getContext()); + CopyAttributes(from, to, [&device](const NamedAttribute &attr) { + return attr.first.strref().front() == '_' || attr.first == device; + }); +} + } // namespace TF } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_ATTRIBUTE_UTILS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index f06fe1280f0..bf894a6c551 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "absl/types/optional.h" +#include "absl/types/variant.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" @@ -31,12 +32,11 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Parser.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -49,12 +49,14 @@ limitations under the License. 
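// Illustrative usage sketch (not part of this patch) of the
// CopyDeviceAndUnderscoredAttributes helper added above: a rewrite that
// replaces one op with another can forward the "device" attribute and all
// "_"-prefixed attributes (e.g. _tpu_replicate) in a single call. The op
// pointers are placeholders.
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"

void ExampleForwardAttributes(mlir::Operation* from, mlir::Operation* to) {
  mlir::TF::CopyDeviceAndUnderscoredAttributes(from, to);
}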
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/logging.h" @@ -62,34 +64,19 @@ limitations under the License. namespace tensorflow { namespace { -// Parses the MLIR module from the mlir_module_string. -Status ParseMlirModule(llvm::StringRef mlir_module_string, - mlir::MLIRContext* mlir_context, - mlir::OwningModuleRef* mlir_module) { - TF_RET_CHECK(!mlir_module_string.empty()) - << "unexpected empty serialized MLIR module string"; - TF_RET_CHECK(mlir_module) << "unexpected null MLIR module pointer"; - - // Make sure we catch any error reported by MLIR and forward it to the TF - // error reporting system. - mlir::StatusScopedDiagnosticHandler error_handler(mlir_context); - - // Parse the module. - *mlir_module = mlir::parseSourceString(mlir_module_string, mlir_context); - if (!*mlir_module) { - return error_handler.Combine( - errors::InvalidArgument("could not parse MLIR module")); +// Extracts shape from XlaArgument as TensorShape. If shape is a xla::Shape, +// that is converted to a TensorShape. +StatusOr GetTensorShapeFromXlaArgument(const XlaArgument& arg) { + if (absl::holds_alternative(arg.shape)) { + TensorShape arg_shape; + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(absl::get(arg.shape), &arg_shape)); + return arg_shape; + } else { + return absl::get(arg.shape); } - - return Status::OK(); } -// Arguments to a computation can be either a tensor or resource. -struct TensorOrResourceShape { - TensorShape shape; - bool is_resource = false; -}; - // Converts arg_shapes to xla::Shape's and store into xla_input_shapes. 
Status GetXlaInputShapes( mlir::ModuleOp module, llvm::ArrayRef arg_shapes, @@ -276,69 +263,69 @@ Status RefineShapes(llvm::ArrayRef arg_shapes, return Status::OK(); } -static void RegisterDialects() { - static bool init_once = []() { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - return true; - }(); - (void)init_once; +static void RegisterDialects(mlir::DialectRegistry& registry) { + mlir::RegisterAllTensorFlowDialects(registry); + mlir::mhlo::registerAllMhloDialects(registry); } } // namespace +void CreateConvertMlirToXlaHloPipeline( + mlir::OpPassManager& pm, llvm::StringRef device_type, + llvm::MutableArrayRef> + custom_legalization_passes) { + pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::TF::CreateTensorListOpsDecompositionPass()); + pm.addPass(mlir::TF::CreateStackOpsDecompositionPass()); + pm.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass()); + pm.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass()); + pm.addPass(mlir::TF::CreatePromoteResourcesToArgsPass()); + pm.addPass(mlir::createSymbolDCEPass()); + // Guarantee all functions have one use, which enables shape inference. + pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + // LegalizeTFControlFlow encapsulates arguments for control flow operations + // with a tuple argument which break the assumption of resource lifting + // inside PromoteResourcesToArgs. + pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass()); + + pm.addNestedPass(mlir::mhlo::createLegalizeTFPass( + /*allow_partial_conversion=*/true, /*legalize_chlo=*/true, + /*tf2xla_fallback_device_type=*/device_type)); + for (auto& target_pass : custom_legalization_passes) { + pm.addNestedPass(std::move(target_pass)); + } + pm.addPass(mlir::mhlo::CreateLegalizeTFCommunicationPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + // Run shape inference pass to propagate shapes through tensor_cast operations + // from static to dynamic shapes. This could be generated if the shape + // inference was originally missing in a TF op but the corresponding HLO op + // had static shape after lowering. + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + // Run LegalizeTFPass again because the previous legalization passes can + // expose more graph pruning and canonicalization opportunities that are + // necessary for the second LegalizeTFPass(allow_partial_conversion=false) + // invocation. + pm.addNestedPass(mlir::mhlo::createLegalizeTFPass( + /*allow_partial_conversion=*/false, /*legalize_chlo=*/true, + /*tf2xla_fallback_device_type=*/device_type)); + // In order to export to XLA, we must sink constants to control flow regions, + // since XLA uses functional control flow. 
+ pm.addNestedPass( + mlir::mhlo::createSinkConstantsToControlFlowPass()); +} + Status ConvertMLIRToXlaComputation( mlir::ModuleOp module_op, llvm::StringRef device_type, xla::XlaComputation* xla_computation, bool use_tuple_args, bool return_tuple, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, - std::vector> custom_legalization_passes) { + llvm::MutableArrayRef> + custom_legalization_passes) { mlir::PassManager tf2xla(module_op.getContext()); - tf2xla.addNestedPass(mlir::createCanonicalizerPass()); - tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass()); - tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass()); - tf2xla.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass()); - tf2xla.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass()); - tf2xla.addPass(mlir::TF::CreatePromoteResourcesToArgsPass()); - tf2xla.addPass(mlir::createSymbolDCEPass()); - // Guarantee all functions have one use, which enables shape inference. - tf2xla.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass()); - tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass()); - // LegalizeTFControlFlow encapsulates arguments for control flow operations - // with a tuple argument which break the assumption of resource lifting - // inside PromoteResourcesToArgs. - tf2xla.addPass(mlir::mhlo::createLegalizeTFControlFlowPass()); - - tf2xla.addNestedPass(mlir::mhlo::createLegalizeTFPass(true)); - for (auto& target_pass : custom_legalization_passes) { - tf2xla.addNestedPass(std::move(target_pass)); - } - tf2xla.addNestedPass(mlir::createCanonicalizerPass()); - tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass()); - - // Leverage tf2xla kernels for ops that didn't get lowered in the previous - // legalization pass. - tf2xla.addPass(mlir::mhlo::createLegalizeTfWithTf2XlaPass(device_type)); - tf2xla.addNestedPass(mlir::createCanonicalizerPass()); - - // Run shape inference pass to propagate shapes through tensor_cast operations - // from static to dynamic shapes. This could be generated if the shape - // inference was originally missing in a TF op but the corresponding HLO op - // had static shape after lowering. - tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass()); - - // Run LegalizeTFPass again because the previous legalization passes can - // expose more graph pruning and canonicalization opportunities that are - // necessary for the second LegalizeTFPass(allow_partial_conversion=false) - // invocation. - tf2xla.addNestedPass(mlir::mhlo::createLegalizeTFPass(false)); - // In order to export to XLA, we must sink constants to control flow regions, - // since XLA uses functional control flow. 
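// Illustrative usage sketch (not part of this patch): callers of the new
// CreateConvertMlirToXlaHloPipeline helper populate a PassManager with it and
// run the result over the module before emitting HLO, which is also what the
// "tf-to-hlo-pipeline" registration later in this change does. The module
// value and device type below are placeholders.
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"

mlir::LogicalResult ExampleLowerToHlo(mlir::ModuleOp module_op) {
  mlir::PassManager pm(module_op.getContext());
  tensorflow::CreateConvertMlirToXlaHloPipeline(
      pm, /*device_type=*/"XLA_CPU_JIT",
      /*custom_legalization_passes=*/{});
  return pm.run(module_op);
}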
- tf2xla.addNestedPass( - mlir::mhlo::createSinkConstantsToControlFlowPass()); + CreateConvertMlirToXlaHloPipeline(tf2xla, device_type, + custom_legalization_passes); if (VLOG_IS_ON(1)) { // Print the whole module after each pass which requires disabling @@ -369,12 +356,13 @@ Status ConvertMLIRToXlaComputation( return Status::OK(); } -static Status CompileMlirToXlaHlo( +Status CompileMlirToXlaHlo( mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, - llvm::StringRef device_type, bool use_tuple_args, + llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple, XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaCompilationResult* compilation_result, - std::vector> custom_legalization_passes) { + llvm::MutableArrayRef> + custom_legalization_passes) { if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op); @@ -391,9 +379,8 @@ static Status CompileMlirToXlaHlo( compilation_result->computation = std::make_shared(); TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation( module_op, device_type, compilation_result->computation.get(), - use_tuple_args, - /*return_tuple=*/true, shape_representation_fn, - std::move(custom_legalization_passes))); + use_tuple_args, use_return_tuple, shape_representation_fn, + custom_legalization_passes)); // Construct mapping from XlaComputation's arg to input edges of execute // node. @@ -420,21 +407,22 @@ Status CompileSerializedMlirToXlaHlo( llvm::StringRef device_type, bool use_tuple_args, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaCompilationResult* compilation_result, - std::vector> custom_legalization_passes) { - RegisterDialects(); + llvm::MutableArrayRef> + custom_legalization_passes) { mlir::MLIRContext mlir_context; + RegisterDialects(mlir_context.getDialectRegistry()); mlir::OwningModuleRef mlir_module; TF_RETURN_IF_ERROR( - ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); + DeserializeMlirModule(mlir_module_string, &mlir_context, &mlir_module)); llvm::SmallVector tensor_or_resource_shapes; tensor_or_resource_shapes.reserve(arg_shapes.size()); for (const auto& arg_shape : arg_shapes) tensor_or_resource_shapes.push_back({arg_shape}); return CompileMlirToXlaHlo(mlir_module.get(), tensor_or_resource_shapes, device_type, use_tuple_args, - shape_representation_fn, compilation_result, - std::move(custom_legalization_passes)); + /*use_return_tuple=*/true, shape_representation_fn, + compilation_result, custom_legalization_passes); } // Rewrites the given module with specified args. For each of the constant args, @@ -442,8 +430,8 @@ Status CompileSerializedMlirToXlaHlo( // removed from the signature. For resource args, their subtypes are populated. // Returns the original indices for the other arguments on success. 
static StatusOr> RewriteWithArgs( - mlir::ModuleOp module, llvm::ArrayRef args) { - mlir::FuncOp main_fn = module.lookupSymbol("main"); + mlir::ModuleOp module_op, llvm::ArrayRef args) { + mlir::FuncOp main_fn = module_op.lookupSymbol("main"); std::vector params; bool has_resource_args = false; @@ -455,7 +443,9 @@ static StatusOr> RewriteWithArgs( if (xla_arg.kind == XlaArgument::kResource) { mlir::Type element_type; TF_RETURN_IF_ERROR(ConvertDataType(xla_arg.type, builder, &element_type)); - auto resource_shape = absl::get(xla_arg.shape).dim_sizes(); + TF_ASSIGN_OR_RETURN(TensorShape arg_shape, + GetTensorShapeFromXlaArgument(xla_arg)); + auto resource_shape = arg_shape.dim_sizes(); llvm::SmallVector resource_subtype_shape( resource_shape.begin(), resource_shape.end()); auto resource_subtype = @@ -481,7 +471,7 @@ static StatusOr> RewriteWithArgs( ConvertTensor(xla_arg.constant_value, &builder)); // TODO(hinsu): Use the actual location of the constant. auto constant = builder.create( - mlir::UnknownLoc::get(module.getContext()), value_attr); + mlir::UnknownLoc::get(module_op.getContext()), value_attr); mlir_arg.replaceAllUsesWith(constant); args_to_erase.push_back(idx); } @@ -503,45 +493,66 @@ static StatusOr> RewriteWithArgs( } Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef args, - llvm::StringRef device_type, bool use_tuple_args, - const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, + mlir::ModuleOp module_op, llvm::ArrayRef args, + llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaCompilationResult* compilation_result, - std::vector> custom_legalization_passes) { - RegisterDialects(); - - mlir::MLIRContext context; - GraphImportConfig config; - config.graph_as_function = true; - auto module_or = - ConvertGraphToMlir(graph, debug_info, flib_def, config, &context); - if (!module_or.ok()) return module_or.status(); - - mlir::ModuleOp module = module_or.ValueOrDie().get(); + llvm::MutableArrayRef> + custom_legalization_passes) { TF_ASSIGN_OR_RETURN(std::vector remaining_params, - RewriteWithArgs(module, {args.data(), args.size()})); + RewriteWithArgs(module_op, args)); llvm::SmallVector arg_shapes; arg_shapes.reserve(remaining_params.size()); for (unsigned idx : remaining_params) { const auto& arg = args[idx]; - arg_shapes.push_back({absl::get(arg.shape), + TF_ASSIGN_OR_RETURN(TensorShape arg_shape, + GetTensorShapeFromXlaArgument(arg)); + arg_shapes.push_back({arg_shape, /*is_resource=*/arg.kind == XlaArgument::kResource}); } - mlir::PassManager pm(&context); + mlir::PassManager pm(module_op.getContext()); mlir::TF::StandardPipelineOptions tf_options; mlir::TF::CreateTFStandardPipeline(pm, tf_options); { - mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); - if (failed(pm.run(module))) return diag_handler.ConsumeStatus(); + mlir::StatusScopedDiagnosticHandler diag_handler(module_op.getContext()); + if (failed(pm.run(module_op))) return diag_handler.ConsumeStatus(); } auto status = CompileMlirToXlaHlo( - module, arg_shapes, device_type, use_tuple_args, shape_representation_fn, - compilation_result, std::move(custom_legalization_passes)); + module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple, + shape_representation_fn, compilation_result, custom_legalization_passes); compilation_result->input_mapping = remaining_params; return status; } +Status CompileGraphToXlaHlo( + const Graph& graph, llvm::ArrayRef args, + 
llvm::StringRef device_type, bool use_tuple_args, + const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, + llvm::MutableArrayRef> + custom_legalization_passes) { + mlir::MLIRContext context; + RegisterDialects(context.getDialectRegistry()); + GraphImportConfig config; + config.graph_as_function = true; + // Disable shape inference during import as some TensorFlow op fails during + // shape inference with dynamic shaped operands. This in turn causes the + // import to fail. Shape inference during import is going to be removed and + // the shape inference pass is run early in the pass pipeline, shape inference + // during import is not necessary. + config.enable_shape_inference = false; + auto module_or = + ConvertGraphToMlir(graph, debug_info, flib_def, config, &context); + if (!module_or.ok()) return module_or.status(); + + mlir::ModuleOp module_op = module_or.ValueOrDie().get(); + return CompileGraphToXlaHlo(module_op, args, device_type, use_tuple_args, + /*use_return_tuple=*/true, + shape_representation_fn, compilation_result, + custom_legalization_passes); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 5c64a65ecbd..dac1c994d03 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -16,10 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_ +#include + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/tf2xla/xla_argument.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -30,6 +33,14 @@ limitations under the License. namespace tensorflow { +// Populates the supplied passmanager with the passes required to run the +// TF MLIR to XLA HLO MLIR conversion/legalization. Custom legalization passes +// can be populated in `custom_legalization_passes`. +void CreateConvertMlirToXlaHloPipeline( + mlir::OpPassManager& pm, llvm::StringRef device_type, + llvm::MutableArrayRef> + custom_legalization_passes); + // Lowers MLIR module to XLA HLO inside an XlaComputation. The input module // should only contain operations in tf dialect. If the input module contains // operation in the tf_executor dialect, for example, returns an error. @@ -61,7 +72,24 @@ Status ConvertMLIRToXlaComputation( xla::XlaComputation* xla_computation, bool use_tuple_args, bool return_tuple, const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr, - std::vector> custom_legalization_passes = {}); + llvm::MutableArrayRef> + custom_legalization_passes = {}); + +// Helper struct representing argument tensor or resource handle shapes. +struct TensorOrResourceShape { + TensorShape shape; + bool is_resource = false; +}; + +// Compiles a MLIR module into XLA HLO, generates all accompanying metadata and +// stores them in CompilationResult. 
+Status CompileMlirToXlaHlo( + mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, + llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple, + XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, + llvm::MutableArrayRef> + custom_legalization_passes); // Compiles a serialized MLIR module into XLA HLO, generates all accompanying // metadata and stores them in CompilationResult. @@ -70,17 +98,33 @@ Status CompileSerializedMlirToXlaHlo( llvm::StringRef device_type, bool use_tuple_args, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaCompilationResult* compilation_result, - std::vector> custom_legalization_passes = {}); + llvm::MutableArrayRef> + custom_legalization_passes = {}); -// Same as the above but takes input as TensorFlow Graph. +// Compiles a TensorFlow Graph (already converted to MLIR, imported with +// tf_executor dialect still present) into XLA HLO, generates all accompanying +// metadata and stores them in CompilationResult. This will rewrite arguments +// and run the TensorFlow standard pipeline prior to invoking +// `CompileMlirToXlaHlo`. +Status CompileGraphToXlaHlo( + mlir::ModuleOp module_op, llvm::ArrayRef args, + llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, + llvm::MutableArrayRef> + custom_legalization_passes); + +// Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata +// and stores them in CompilationResult. // TODO(lyandy): Allow populating of targets/control outputs. Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef args, + const Graph& graph, llvm::ArrayRef args, llvm::StringRef device_type, bool use_tuple_args, const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaCompilationResult* compilation_result, - std::vector> custom_legalization_passes = {}); + llvm::MutableArrayRef> + custom_legalization_passes = {}); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_pass.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_pass.cc new file mode 100644 index 00000000000..57267ff027f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_pass.cc @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" + +namespace { +void CreateConvertMlirToXlaHloPipelineWithDefaults(mlir::OpPassManager& pm) { + tensorflow::CreateConvertMlirToXlaHloPipeline( + pm, /*device_type=*/"XLA_CPU_JIT", + /*custom_legalization_passes=*/{}); +} + +mlir::PassPipelineRegistration<> pipeline( + "tf-to-hlo-pipeline", + "Convert TF dialect to HLO dialect (used for compilation in bridge).", + CreateConvertMlirToXlaHloPipelineWithDefaults); +} // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc deleted file mode 100644 index 6ebf6897bb1..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ /dev/null @@ -1,542 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" - -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/function_ops.h" -#include "tensorflow/cc/ops/resource_variable_ops.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/graph/testlib.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/stream_executor/lib/statusor.h" - -namespace tensorflow { -namespace { - -// A dummy shape representation function that simply converts given shape into -// an xla::Shape without assigning any layouts. 
-xla::StatusOr TestShapeRepresentation(const TensorShape& shape, - DataType type, - bool use_fast_memory) { - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); - return xla_shape; -} - -TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) { - constexpr char invalid_mlir_module[] = - "totally @invalid MLIR module {here} <-"; - std::vector arg_shapes; - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - invalid_mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT); - EXPECT_EQ(s.ToString(), - "Invalid argument: could not parse MLIR module-:1:1: error: " - "custom op 'totally' is unknown\n"); -} - -constexpr llvm::StringRef kBinaryAddModule = R"( - module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor, tensor) -> tensor - return %0 : tensor - } - } -)"; - -TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) { - std::vector arg_shapes(2, TensorShape()); - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - kBinaryAddModule, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - constexpr char expected_hlo_module_string[] = R"(HloModule main.6 - -ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) { - %arg_tuple.1 = (f32[], f32[]) parameter(0) - %get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.1), index=0 - %get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.1), index=1 - %add.4 = f32[] add(f32[] %get-tuple-element.2, f32[] %get-tuple-element.3) - ROOT %tuple.5 = (f32[]) tuple(f32[] %add.4) -} - -)"; - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); - - // Expect an in order input mapping. - EXPECT_EQ(compilation_result.input_mapping, std::vector({0, 1})); - - // Expect a single tuple-shape, containing two F32 scalars. - EXPECT_EQ(compilation_result.xla_input_shapes.size(), 1); - xla::Shape expected_input_shape = - xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}), - xla::ShapeUtil::MakeShape(xla::F32, {})}); - EXPECT_EQ(compilation_result.xla_input_shapes.front(), expected_input_shape); - - // Expect output shape is a tuple shape containing a single F32 Scalar type. - const xla::Shape output_shape = - xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}); - const xla::Shape tuple_output_shape = - xla::ShapeUtil::MakeTupleShape({output_shape}); - EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape); - - // Expect exactly 1 OutputDescription. - EXPECT_EQ(compilation_result.outputs.size(), 1); - const XlaCompiler::OutputDescription& output_desc = - compilation_result.outputs.front(); - EXPECT_EQ(output_desc.type, DataType::DT_FLOAT); - EXPECT_EQ(output_desc.shape, TensorShape()); - EXPECT_FALSE(output_desc.is_constant); - EXPECT_FALSE(output_desc.is_tensor_list); - - // Expect no resource updates from computation. 
- EXPECT_TRUE(compilation_result.resource_updates.empty()); -} - -TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) { - std::vector arg_shapes(2, TensorShape()); - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - kBinaryAddModule, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/false, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - constexpr char expected_hlo_module_string[] = R"(HloModule main.5 - -ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) { - %Arg_0.1 = f32[] parameter(0) - %Arg_1.2 = f32[] parameter(1) - %add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2) - ROOT %tuple.4 = (f32[]) tuple(f32[] %add.3) -} - -)"; - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); - - // Expect an in order input mapping. - EXPECT_EQ(compilation_result.input_mapping, std::vector({0, 1})); - - // Expect two inputs, each containing a F32 scalar. - EXPECT_EQ(compilation_result.xla_input_shapes.size(), 2); - xla::Shape expected_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {}); - EXPECT_EQ(compilation_result.xla_input_shapes[0], expected_input_shape); - EXPECT_EQ(compilation_result.xla_input_shapes[1], expected_input_shape); - - // Expect output shape is a tuple shape containing a single F32 Scalar type. - const xla::Shape output_shape = - xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}); - const xla::Shape tuple_output_shape = - xla::ShapeUtil::MakeTupleShape({output_shape}); - EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape); - - // Expect exactly 1 OutputDescription. - EXPECT_EQ(compilation_result.outputs.size(), 1); - const XlaCompiler::OutputDescription& output_desc = - compilation_result.outputs.front(); - EXPECT_EQ(output_desc.type, DataType::DT_FLOAT); - EXPECT_EQ(output_desc.shape, TensorShape()); - EXPECT_FALSE(output_desc.is_constant); - EXPECT_FALSE(output_desc.is_tensor_list); - - // Expect no resource updates from computation. - EXPECT_TRUE(compilation_result.resource_updates.empty()); -} - -// Tests that foldable ops are constant-folded to enable legalization of ops -// that require compile time constant operand. -TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) { - // "tf.Shape" can only be folded away after shape inference. tf.Reshape can - // only be lowered when tf.Shape is folded into a constant. 
- constexpr char mlir_module[] = R"( - module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {mhlo.is_same_data_across_replicas}) -> tensor<10x19xf32> { - %0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64> - %1 = "tf.Reshape"(%arg1, %0) : (tensor<19x10xf32>, tensor<2xi64>) -> tensor<10x19xf32> - return %1 : tensor<10x19xf32> - } - } - )"; - - std::vector arg_shapes{TensorShape({10, 19}), - TensorShape({19, 10})}; - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - constexpr char expected_hlo_module_string[] = R"(HloModule main.6 - -ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) { - %arg_tuple.1 = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0), parameter_replication={false,true} - %get-tuple-element.2 = f32[10,19]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=0 - %get-tuple-element.3 = f32[19,10]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=1 - %reshape.4 = f32[10,19]{1,0} reshape(f32[19,10]{1,0} %get-tuple-element.3) - ROOT %tuple.5 = (f32[10,19]{1,0}) tuple(f32[10,19]{1,0} %reshape.4) -} - -)"; - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); -} - -TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) { - constexpr char mlir_module[] = R"( - module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor { - %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor) -> tensor - return %0 : tensor - } - } - )"; - - std::vector arg_shapes{TensorShape({10, 17}), - TensorShape({17, 19})}; - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - - constexpr char expected_signature[] = - R"((arg_tuple.1: (f32[10,17], f32[17,19])) -> (f32[10,19]))"; - EXPECT_THAT(status_or_hlo_module.ValueOrDie()->ToString(), - ::testing::HasSubstr(expected_signature)); -} - -TEST(CompileSerializedMlirToXlaHloTest, ShapeInferenceAfterLegalization) { - constexpr char mlir_module[] = R"( - module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor<8x16x16x64xbf16>, %arg1: tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) { - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, 
tensor<64xf32>, tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) - return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32> - } - } - )"; - - std::vector arg_shapes{TensorShape({8, 16, 16, 64}), - TensorShape({64})}; - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - - constexpr char expected_signature[] = - R"(-> (bf16[8,16,16,64], f32[64], f32[64], f32[64], f32[64], f32[0]))"; - EXPECT_THAT(status_or_hlo_module.ValueOrDie()->ToString(), - ::testing::HasSubstr(expected_signature)); -} - -TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) { - constexpr char mlir_module[] = R"( -module attributes {tf.versions = {producer = 179 : i32}} { - func @main() -> (tensor<0xi32>, tensor<0xi32>) { - %0 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> - %r0, %r1 = "tf.BroadcastGradientArgs"(%0, %0) {T = i32} : (tensor<0xi32>, tensor<0xi32>) -> (tensor<0xi32>, tensor<0xi32>) - return %r0, %r1 : tensor<0xi32>, tensor<0xi32> - } -} -)"; - - std::vector arg_shapes(2, TensorShape()); - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - constexpr char expected_hlo_module_string[] = R"(HloModule main.4 - -ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) { - %arg_tuple.1 = () parameter(0) - %constant.2 = s32[0]{0} constant({}) - ROOT %tuple.3 = (s32[0]{0}, s32[0]{0}) tuple(s32[0]{0} %constant.2, s32[0]{0} %constant.2) -} - -)"; - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); -} - -// The following xla::OpSharding protos are used: -// Serialized string: -// "\08\03\1A\02\01\02\22\02\00\01" -// Proto debug string: -// type: OTHER -// tile_assignment_dimensions: 1 -// tile_assignment_dimensions: 2 -// tile_assignment_devices: 0 -// tile_assignment_devices: 1 -// -// Serialized string: -// "\08\01\1A\01\01\22\01\00" -// Proto debug string: -// type: MAXIMAL -// tile_assignment_dimensions: 1 -// tile_assignment_devices: 0 -// -// Serialized string: -// "" -// Proto debug string (empty but would equivalent to): -// type: REPLICATED -TEST(CompileSerializedMlirToXlaHloTest, ArgumentSharding) { - constexpr char mlir_module[] = R"( -module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {mhlo.sharding = ""}) { - return - } -} -)"; - - std::vector arg_shapes{TensorShape({128, 
10}), - TensorShape({10, 1024}), - TensorShape({128, 1024})}; - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - constexpr char expected_hlo_module_string[] = R"(HloModule main.6 - -ENTRY %main.6 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> () { - %arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}} - %get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0 - %get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1 - %get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2 - ROOT %tuple.5 = () tuple() -} - -)"; - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); -} - -TEST(CompileSerializedMlirToXlaHloTest, BadArgumentSharding) { - constexpr char mlir_module[] = R"( -module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "bad_sharding"}) { - return - } -} -)"; - - std::vector arg_shapes{TensorShape({128, 10})}; - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - ASSERT_FALSE(s.ok()); - EXPECT_EQ(s.error_message(), - "failed to parse argument sharding 0 'bad_sharding'"); -} - -TEST(CompileSerializedMlirToXlaHloTest, ResultSharding) { - constexpr char mlir_module[] = R"( -module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 351 : i32}} { - func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {mhlo.sharding = ""}) { - return %arg0, %arg1, %arg2 : tensor<128x10xf32>, tensor<10x1024xf32>, tensor<128x1024xf32> - } -} -)"; - - std::vector arg_shapes{TensorShape({128, 10}), - TensorShape({10, 1024}), - TensorShape({128, 1024})}; - XlaCompiler::CompilationResult compilation_result; - - Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, "XLA_CPU_JIT", - /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); - TF_ASSERT_OK(s); - - const xla::HloModuleConfig module_config( - compilation_result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - compilation_result.computation->proto(), module_config); - TF_ASSERT_OK(status_or_hlo_module.status()); - constexpr char expected_hlo_module_string[] = R"(HloModule main.9 - -ENTRY %main.9 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> (f32[128,10], f32[10,1024], f32[128,1024]) { - %arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, 
f32[128,1024]{1,0}) parameter(0) - %get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0 - %reshape.5 = f32[128,10]{1,0} reshape(f32[128,10]{1,0} %get-tuple-element.2) - %get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1 - %reshape.6 = f32[10,1024]{1,0} reshape(f32[10,1024]{1,0} %get-tuple-element.3) - %get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2 - %reshape.7 = f32[128,1024]{1,0} reshape(f32[128,1024]{1,0} %get-tuple-element.4) - ROOT %tuple.8 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) tuple(f32[128,10]{1,0} %reshape.5, f32[10,1024]{1,0} %reshape.6, f32[128,1024]{1,0} %reshape.7), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}} -} - -)"; - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); -} - -// Verify that conversion from Graph to MLIR and empty shape representation -// function is successful. -TEST(CompileGraphToXlaHlo, Basic) { - FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); - Graph graph(OpRegistry::Global()); - - Node* arg = test::graph::Arg(&graph, 0, DT_FLOAT); - test::graph::Retval(&graph, 0, arg); - - XlaCompiler::CompilationResult result; - XlaCompiler::Argument compiler_arg; - compiler_arg.kind = XlaCompiler::Argument::kParameter; - compiler_arg.shape = TensorShape(); - - TF_ASSERT_OK( - CompileGraphToXlaHlo(graph, /*args=*/{compiler_arg}, "XLA_CPU_JIT", - /*use_tuple_args=*/false, flib_def, GraphDebugInfo(), - /*shape_representation_fn=*/nullptr, &result)); - - const xla::HloModuleConfig module_config( - result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - result.computation->proto(), module_config); - ASSERT_TRUE(status_or_hlo_module.ok()); - - constexpr char expected_hlo_module_string[] = R"(HloModule main.3 - -ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) { - %Arg_0.1 = f32[] parameter(0) - ROOT %tuple.2 = (f32[]) tuple(f32[] %Arg_0.1) -} - -)"; - - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); -} - -// Tests a conversion from Graph to MLIR with resource arguments. 
-TEST(CompileGraphToXlaHlo, Resources) { - FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); - Graph graph(OpRegistry::Global()); - - Scope scope = Scope::NewRootScope().ExitOnError(); - auto val = ops::_Arg(scope.WithOpName("arg0"), DT_FLOAT, 0); - auto var = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1); - auto assign = - ops::AssignVariableOp(scope.WithOpName("assign_variable"), var, val); - TF_ASSERT_OK(scope.ToGraph(&graph)); - - XlaCompiler::CompilationResult result; - XlaCompiler::Argument arg0; - arg0.kind = XlaCompiler::Argument::kParameter; - arg0.shape = TensorShape({2}); - XlaCompiler::Argument arg1; - arg1.kind = XlaCompiler::Argument::kResource; - arg1.shape = TensorShape({2}); - arg1.type = DT_FLOAT; - - TF_ASSERT_OK( - CompileGraphToXlaHlo(graph, /*args=*/{arg0, arg1}, "XLA_CPU_JIT", - /*use_tuple_args=*/false, flib_def, GraphDebugInfo(), - /*shape_representation_fn=*/nullptr, &result)); - - EXPECT_EQ(result.outputs.size(), 0); - ASSERT_EQ(result.resource_updates.size(), 1); - const auto& resource_update = result.resource_updates[0]; - EXPECT_EQ(resource_update.input_index, 1); - EXPECT_EQ(resource_update.modified, true); - EXPECT_EQ(resource_update.shape, TensorShape({2})); - EXPECT_EQ(resource_update.type, DT_FLOAT); - - const xla::HloModuleConfig module_config( - result.computation->GetProgramShape().ValueOrDie()); - auto status_or_hlo_module = xla::HloModule::CreateFromProto( - result.computation->proto(), module_config); - ASSERT_TRUE(status_or_hlo_module.ok()); - - constexpr char expected_hlo_module_string[] = - R"(HloModule main.4, input_output_alias={ {0}: 1 } - -ENTRY %main.4 (Arg_0.1: f32[2], Arg_1.2: f32[2]) -> (f32[2]) { - %Arg_1.2 = f32[2]{0} parameter(1) - %Arg_0.1 = f32[2]{0} parameter(0) - ROOT %tuple.3 = (f32[2]{0}) tuple(f32[2]{0} %Arg_0.1) -} - -)"; - - EXPECT_EQ(expected_hlo_module_string, - status_or_hlo_module.ValueOrDie()->ToString()); -} - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index 359314a64b0..05e1f059029 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -36,8 +36,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tstring.h" @@ -161,7 +161,7 @@ StatusOr ConvertTensor(const Tensor& input_tensor, default: // TODO(shpeisman): restructure code to reuse dialect pointer across // calls. - auto* dialect = builder->getContext()->getRegisteredDialect("tf"); + auto* dialect = builder->getContext()->getLoadedDialect("tf"); return OpaqueElementsAttr::get(dialect, type, MangleTensor(input_tensor)); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index bf96e3d1df4..6266a5e2195 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
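
For readers following the convert_tensor_test.cc hunk below: the tests move from one-time global dialect registration to loading the TensorFlow dialect into each MLIRContext. A minimal, self-contained sketch of that pattern, assuming TensorFlowDialect is exposed via tf_ops.h as it is elsewhere in this patch (illustrative only, not part of the diff):

#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

// Load the TensorFlow dialect into one specific context before building or
// parsing any tf.* ops, types, or attributes with that context.
static void RegisterDialects(mlir::MLIRContext& context) {
  context.loadDialect<mlir::TF::TensorFlowDialect>();
}

// Typical use inside a test body:
//   mlir::MLIRContext context;
//   RegisterDialects(context);
//   mlir::Builder b(&context);
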
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -33,16 +34,13 @@ limitations under the License. namespace tensorflow { namespace { -static void RegisterDialects() { - static bool init_once = []() { - mlir::registerDialect(); - return true; - }(); - (void)init_once; +static void RegisterDialects(mlir::MLIRContext &context) { + context.loadDialect(); } TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) { mlir::MLIRContext context; + RegisterDialects(context); mlir::Builder b(&context); PartialTensorShape output_shape = @@ -52,6 +50,7 @@ TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) { TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) { mlir::MLIRContext context; + RegisterDialects(context); mlir::Builder b(&context); PartialTensorShape output_shape = ConvertTypeToTensorShape( @@ -61,6 +60,7 @@ TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) { TEST(ConvertTypeToTensorTypeTest, FullyDefinedRankedTensorType) { mlir::MLIRContext context; + RegisterDialects(context); mlir::Builder b(&context); PartialTensorShape output_shape = ConvertTypeToTensorShape( @@ -77,8 +77,8 @@ TEST(ConvertTypeToTensorTypeTest, ScalarTensorType) { } TEST(ConvertTypeToTensorTypeTest, ConvertStringTensor) { - RegisterDialects(); mlir::MLIRContext context; + RegisterDialects(context); mlir::Builder b(&context); // Create the sample tensor to convert. @@ -123,9 +123,8 @@ class ConvertTensorTest : public ::testing::Test { }; TEST_F(ConvertTensorTest, Simple) { - RegisterDialects(); - mlir::MLIRContext context; + RegisterDialects(context); ASSERT_NO_FATAL_FAILURE(VerifyConversion( {Eigen::half(1.0)}, DT_HALF, mlir::FloatType::getF16(&context))); ASSERT_NO_FATAL_FAILURE( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc index 0caceb69510..0d035e8f864 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc @@ -91,64 +91,62 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) { } Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { - switch (type.getKind()) { - case mlir::StandardTypes::F16: - *dtype = DT_HALF; - return Status::OK(); - case mlir::StandardTypes::F32: - *dtype = DT_FLOAT; - return Status::OK(); - case mlir::StandardTypes::F64: - *dtype = DT_DOUBLE; - return Status::OK(); - case mlir::StandardTypes::BF16: - *dtype = DT_BFLOAT16; - return Status::OK(); - case mlir::StandardTypes::Integer: { - const auto& itype = type.cast(); - switch (itype.getWidth()) { - case 1: - *dtype = DT_BOOL; - return Status::OK(); - case 8: - *dtype = itype.isUnsigned() ? DT_UINT8 : DT_INT8; - return Status::OK(); - case 16: - *dtype = itype.isUnsigned() ? DT_UINT16 : DT_INT16; - return Status::OK(); - case 32: - *dtype = itype.isUnsigned() ? DT_UINT32 : DT_INT32; - return Status::OK(); - case 64: - *dtype = itype.isUnsigned() ? 
DT_UINT64 : DT_INT64; - return Status::OK(); - default: - return errors::Unimplemented( - absl::StrCat("Converting ", debugString(type), " to DataType")); - } - } - case mlir::StandardTypes::Complex: { - auto etype = type.cast().getElementType(); - if (etype.isF32()) { - *dtype = DT_COMPLEX64; - return Status::OK(); - } else if (etype.isF64()) { - *dtype = DT_COMPLEX128; - return Status::OK(); - } - return errors::Unimplemented( - absl::StrCat("Converting ", debugString(type), " to DataType")); - } -#define HANDLE_TF_TYPE(tftype, enumerant, name) \ - case mlir::TF::TensorFlowTypes::enumerant: \ - *dtype = DT_##enumerant; \ + if (type.isF16()) { + *dtype = DT_HALF; return Status::OK(); + } else if (type.isF32()) { + *dtype = DT_FLOAT; + return Status::OK(); + } else if (type.isF64()) { + *dtype = DT_DOUBLE; + return Status::OK(); + } else if (type.isBF16()) { + *dtype = DT_BFLOAT16; + return Status::OK(); + } else if (auto itype = type.dyn_cast()) { + switch (itype.getWidth()) { + case 1: + *dtype = DT_BOOL; + return Status::OK(); + case 8: + *dtype = itype.isUnsigned() ? DT_UINT8 : DT_INT8; + return Status::OK(); + case 16: + *dtype = itype.isUnsigned() ? DT_UINT16 : DT_INT16; + return Status::OK(); + case 32: + *dtype = itype.isUnsigned() ? DT_UINT32 : DT_INT32; + return Status::OK(); + case 64: + *dtype = itype.isUnsigned() ? DT_UINT64 : DT_INT64; + return Status::OK(); + default: + return errors::Unimplemented( + absl::StrCat("Converting ", debugString(type), " to DataType")); + } + } else if (auto complex_type = type.dyn_cast()) { + auto etype = complex_type.getElementType(); + if (etype.isF32()) { + *dtype = DT_COMPLEX64; + return Status::OK(); + } else if (etype.isF64()) { + *dtype = DT_COMPLEX128; + return Status::OK(); + } + return errors::Unimplemented( + absl::StrCat("Converting ", debugString(type), " to DataType")); + } + +#define HANDLE_TF_TYPE(tftype, enumerant, name) \ + if (type.isa()) { \ + *dtype = DT_##enumerant; \ + return Status::OK(); \ + } // NOLINTNEXTLINE #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" - default: - return errors::Unimplemented( - absl::StrCat("Converting ", debugString(type), " to DataType")); - } + + return errors::Unimplemented( + absl::StrCat("Converting ", debugString(type), " to DataType")); } Status ConvertToDataType(Type type, DataType* dtype) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc index bf0b3b75ace..81892934efe 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -25,6 +26,8 @@ limitations under the License. 
#include "llvm/Support/Regex.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/core/common_runtime/device.h" @@ -155,4 +158,19 @@ mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, llvm::formatv("unsupported '{0}' attribute", kDevicesAttr)); } +mlir::LogicalResult GetDeviceOrdinalFromDeviceString(mlir::Location loc, + llvm::StringRef device, + int64_t* device_ordinal) { + DeviceNameUtils::ParsedName parsed_name; + if (!DeviceNameUtils::ParseFullName( + absl::string_view(device.data(), device.size()), &parsed_name)) + return mlir::emitError(loc) << "invalid device '" << device << "'"; + + if (!parsed_name.has_id) + return mlir::emitError(loc) << "device '" << device << "' has no id"; + + *device_ordinal = parsed_name.id; + return mlir::success(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/device_util.h index 893e118024c..14e48bf7710 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_ #include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" @@ -41,6 +42,12 @@ void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set); mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, mlir::TF::RuntimeDevices* devices); +// Parses a device string and returns its ordinal (id). This will return an +// error if the device string is invalid or has no id. 
+mlir::LogicalResult GetDeviceOrdinalFromDeviceString(mlir::Location loc, + llvm::StringRef device, + int64_t* device_ordinal); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc index bc849e1d116..1da1f5973f6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc @@ -205,5 +205,47 @@ TEST(DeviceUtilTest, GetGpuDeviceMetadata) { ASSERT_FALSE(meta_1.hasValue()); } +TEST(DeviceUtilTest, GetDeviceOrdinalFromDeviceString) { + const std::string tpu0 = "/job:worker/replica:0/task:0/device:TPU:0"; + const std::string tpu1 = "/job:worker/replica:0/task:0/device:TPU:1"; + + mlir::MLIRContext context; + auto unknown_loc = mlir::UnknownLoc::get(&context); + + int64_t device_ordinal0 = -1; + mlir::LogicalResult result0 = + GetDeviceOrdinalFromDeviceString(unknown_loc, tpu0, &device_ordinal0); + EXPECT_TRUE(mlir::succeeded(result0)); + EXPECT_EQ(device_ordinal0, 0); + + int64_t device_ordinal1 = -1; + mlir::LogicalResult result1 = + GetDeviceOrdinalFromDeviceString(unknown_loc, tpu1, &device_ordinal1); + EXPECT_TRUE(mlir::succeeded(result1)); + EXPECT_EQ(device_ordinal1, 1); +} + +TEST(DeviceUtilTest, GetDeviceOrdinalFromDeviceStringInvalid) { + mlir::MLIRContext context; + auto unknown_loc = mlir::UnknownLoc::get(&context); + + int64_t device_ordinal = -1; + mlir::LogicalResult result = GetDeviceOrdinalFromDeviceString( + unknown_loc, "bad_device", &device_ordinal); + EXPECT_TRUE(mlir::failed(result)); +} + +TEST(DeviceUtilTest, GetDeviceOrdinalFromDeviceStringNoId) { + const std::string tpu_no_id = "/job:worker/replica:0/task:0/device:TPU"; + + mlir::MLIRContext context; + auto unknown_loc = mlir::UnknownLoc::get(&context); + + int64_t device_ordinal = -1; + mlir::LogicalResult result = + GetDeviceOrdinalFromDeviceString(unknown_loc, tpu_no_id, &device_ordinal); + EXPECT_TRUE(mlir::failed(result)); +} + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h index 4feb3837357..b5f2acc581d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h @@ -27,7 +27,7 @@ limitations under the License. namespace mlir { // TensorFlow's Status is used for error reporting back to callers. -using tensorflow::Status; +using ::tensorflow::Status; // Diagnostic handler that collects all the diagnostics reported and can produce // a Status to return to callers. This is for the case where MLIR functions are diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index 0364b935b92..67c2aebf121 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -22,6 +22,7 @@ limitations under the License. 
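
The ConvertAttributes rewrite later in this export_utils.cc diff replaces the attribute-kind switch with llvm::TypeSwitch dispatch over concrete attribute classes. Below is a minimal standalone sketch of that dispatch pattern with the template parameters spelled out; the class list is abbreviated and DescribeAttr is a hypothetical helper, not code from this change:

#include <string>

#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project

// Dispatch on the concrete attribute class; a single .Case can list several
// classes, and .Default handles anything not matched above it.
static std::string DescribeAttr(mlir::Attribute attr) {
  return llvm::TypeSwitch<mlir::Attribute, std::string>(attr)
      .Case<mlir::BoolAttr>([](mlir::BoolAttr) { return std::string("bool"); })
      .Case<mlir::IntegerAttr, mlir::FloatAttr>(
          [](auto) { return std::string("number"); })
      .Case<mlir::StringAttr>(
          [](mlir::StringAttr) { return std::string("string"); })
      .Default([](mlir::Attribute) { return std::string("unhandled"); });
}
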
#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -227,25 +228,13 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) { return Status::OK(); } -// Updates NodeDef constructed out of an MLIR If op to map it to either -// TensorFlow StatelessIf or If op depending on the additional attribute. -void UpdateCompositeIfOp(NodeDef* node_def) { +// Updates NodeDef constructed out of an MLIR Case/IfW/While op to map it to +// either TensorFlow StatelessX or X op depending on the additional attribute. +void UpdateCompositeOp(NodeDef* node_def) { auto it = node_def->mutable_attr()->find("is_stateless"); if (it != node_def->attr().end()) { if (it->second.b()) { - *node_def->mutable_op() = "StatelessIf"; - } - node_def->mutable_attr()->erase(it); - } -} - -// Updates NodeDef constructed out of an MLIR While op to map it to either -// TensorFlow StatelessWhile or While op depending on the additional attribute. -void UpdateCompositeWhileOp(NodeDef* node_def) { - auto it = node_def->mutable_attr()->find("is_stateless"); - if (it != node_def->attr().end()) { - if (it->second.b()) { - *node_def->mutable_op() = "StatelessWhile"; + *node_def->mutable_op() = "Stateless" + node_def->op(); } node_def->mutable_attr()->erase(it); } @@ -352,8 +341,9 @@ StatusOr> GetOperationNodeDef( TF_RETURN_IF_ERROR(ConvertLocation( inst->getLoc(), node_def->mutable_experimental_debug_info())); - if (node_def->op() == "If") UpdateCompositeIfOp(node_def.get()); - if (node_def->op() == "While") UpdateCompositeWhileOp(node_def.get()); + if (node_def->op() == "Case") UpdateCompositeOp(node_def.get()); + if (node_def->op() == "If") UpdateCompositeOp(node_def.get()); + if (node_def->op() == "While") UpdateCompositeOp(node_def.get()); return node_def; } @@ -379,65 +369,36 @@ Status ConvertAttributes( name = mangling_util::DemangleAttributeName(name); } AttrValue value; - switch (attr.getKind()) { - case mlir::StandardAttributes::SymbolRef: { - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - func_call_attrs[string(name)] = value; - continue; - } - case mlir::StandardAttributes::Integer: - if (auto boolAttr = attr.dyn_cast()) { - TF_RETURN_IF_ERROR(ConvertAttribute(boolAttr, &value)); - } else { - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - } - break; - case mlir::StandardAttributes::Float: - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - break; - case mlir::StandardAttributes::String: - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - break; - case mlir::StandardAttributes::Array: - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - break; - case mlir::StandardAttributes::DenseIntOrFPElements: - case mlir::StandardAttributes::DenseStringElements: - case mlir::StandardAttributes::OpaqueElements: - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - break; - case mlir::StandardAttributes::Type: - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - break; - case mlir::StandardAttributes::Unit: - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - break; - case static_cast(mlir::TF::AttrKind::SHAPE): - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - break; - case static_cast(mlir::TF::AttrKind::FUNC): { - TF_RETURN_IF_ERROR( - 
ConvertAttribute(attr.cast(), &value)); - func_call_attrs[string(name)] = value; - continue; - } - // AffineMap kind is not implemented. - case mlir::StandardAttributes::AffineMap: - return errors::Unimplemented("AffineMap attribute (needed for '", - name_strref, "') unimplemented"); - default: - return errors::Unimplemented("Unhandled attribute kind for attribute '", - name_strref, '\''); + if (auto symbol_ref = attr.dyn_cast()) { + TF_RETURN_IF_ERROR( + ConvertAttribute(symbol_ref.cast(), &value)); + func_call_attrs[string(name)] = value; + continue; } + if (auto func_attr = attr.dyn_cast()) { + TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, &value)); + func_call_attrs[string(name)] = value; + continue; + } + if (attr.isa()) { + // AffineMapAttr is not implemented. + return errors::Unimplemented("AffineMap attribute (needed for '", + name_strref, "') unimplemented"); + } + TF_RETURN_IF_ERROR( + llvm::TypeSwitch(attr) + .Case( + [&](auto derived_attr) { + return ConvertAttribute(derived_attr, &value); + }) + .Default([&](mlir::Attribute) { + return errors::Unimplemented( + "Unhandled attribute kind for attribute '", name_strref, + '\''); + })); + // According to the NodeDef proto definition, an attribute name from the // input TensorFlow GraphDef shouldn't contain '.'. If it does appear in // the attribute from MLIR, it is treated as an attribute from function diff --git a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc new file mode 100644 index 00000000000..8e9495c0454 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc @@ -0,0 +1,56 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" + +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { + +std::string SerializeMlirModule(mlir::ModuleOp module_op) { + std::string serialized_mlir_module; + llvm::raw_string_ostream os(serialized_mlir_module); + mlir::OpPrintingFlags print_flags; + print_flags.enableDebugInfo(); + module_op.print(os, print_flags); + return std::move(os.str()); +} + +Status DeserializeMlirModule(llvm::StringRef serialized_mlir_module, + mlir::MLIRContext* mlir_context, + mlir::OwningModuleRef* mlir_module) { + TF_RET_CHECK(!serialized_mlir_module.empty()) + << "unexpected empty serialized MLIR module string"; + TF_RET_CHECK(mlir_module) << "unexpected null MLIR module pointer"; + + // Make sure we catch any error reported by MLIR and forward it to the TF + // error reporting system. 
+ mlir::StatusScopedDiagnosticHandler error_handler(mlir_context); + + // Parse the module. + *mlir_module = mlir::parseSourceString(serialized_mlir_module, mlir_context); + if (!*mlir_module) + return error_handler.Combine( + errors::InvalidArgument("could not parse MLIR module")); + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h new file mode 100644 index 00000000000..12d1c39132e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SERIALIZE_MLIR_MODULE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SERIALIZE_MLIR_MODULE_UTILS_H_ + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Prints a MLIR module `module_op` and returns it as a string. +std::string SerializeMlirModule(mlir::ModuleOp module_op); + +// Parses a MLIR module from `mlir_module_string` into `mlir_module` with +// context `mlir_context`. +Status DeserializeMlirModule(llvm::StringRef serialized_mlir_module, + mlir::MLIRContext* mlir_context, + mlir::OwningModuleRef* mlir_module); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SERIALIZE_MLIR_MODULE_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc new file mode 100644 index 00000000000..d1815a4a88b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc @@ -0,0 +1,434 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h" + +#include +#include +#include +#include + +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" +#include "tensorflow/compiler/mlir/utils/array_container_utils.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/types.h" + +#define DEBUG_TYPE "tf-shape-inference-utils" + +using ::tensorflow::int64; +using tensorflow::shape_inference::DimensionHandle; +using tensorflow::shape_inference::InferenceContext; +using tensorflow::shape_inference::ShapeHandle; + +namespace mlir { +namespace TF { + +namespace { + +// Extracts attributes from a MLIR operation, including derived attributes. +NamedAttrList GetAllAttributesFromOperation(Operation* op) { + NamedAttrList attr_list; + attr_list.append(op->getAttrDictionary().getValue()); + + if (auto derived = dyn_cast(op)) { + auto materialized = derived.materializeDerivedAttributes(); + attr_list.append(materialized.getValue()); + } + + return attr_list; +} + +// Extracts a PartialTensorShape from the MLIR type. +Optional GetShapeFromMlirType(Type t) { + if (auto ranked_type = t.dyn_cast()) { + // Convert the MLIR shape indices (int64_t) to TensorFlow indices + // (int64). + ArrayRef shape = ranked_type.getShape(); + SmallVector tf_shape(shape.begin(), shape.end()); + return tensorflow::PartialTensorShape( + MutableArrayRefToSpan(tf_shape)); + } + return None; +} + +// Gets the subtype's shape and data type for `type`. Templated to support both +// ResourceType and VariantType. 
+template +std::unique_ptr>> +GetSubtypesHelper(Type type) { + auto type_with_subtypes = + type.cast().getElementType().dyn_cast(); + if (!type_with_subtypes || type_with_subtypes.getSubtypes().empty()) { + return nullptr; + } + auto shapes_and_types = std::make_unique>>(); + for (auto subtype : type_with_subtypes.getSubtypes()) { + auto shape = GetShapeFromMlirType(subtype); + // handle_shapes_and_types requires all shapes to be known. So if any + // subtype is unknown, clear the vector. + if (!shape) { + shapes_and_types = nullptr; + break; + } + tensorflow::DataType dtype; + auto status = + tensorflow::ConvertToDataType(subtype.getElementType(), &dtype); + assert(status.ok() && "Unknown element type"); + shapes_and_types->emplace_back(*shape, dtype); + } + return shapes_and_types; +} + +// Gets the subtype's shape and data type for `type`. +std::unique_ptr>> +GetSubtypes(Type type) { + auto subclasses = GetSubtypesHelper(type); + if (subclasses) return subclasses; + return GetSubtypesHelper(type); +} + +// Returns a shape inference function call failure at `location`. +LogicalResult EmitErrorFromShapeFunction(Optional location, + StringRef op_name, + StringRef error_message) { + LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << op_name + << "': " << error_message << "\n"); + return emitOptionalError( + location, + llvm::formatv( + "TensorFlow shape inference function errored for op '{0}': {1}", + op_name, error_message) + .str()); +} + +// Extracts shape from a shape handle and inference context. +Optional> GetShapeFromHandle(InferenceContext& context, + const ShapeHandle& sh) { + if (!context.RankKnown(sh)) return None; + SmallVector shape; + for (int dim : llvm::seq(0, context.Rank(sh))) + shape.push_back(context.Value(context.Dim(sh, dim))); + return shape; +} + +// Creates a tensor type from a shape handle and element type. +TensorType CreateTensorType(InferenceContext& context, const ShapeHandle& sh, + Type element_type) { + auto shape = GetShapeFromHandle(context, sh); + if (shape.hasValue()) + return RankedTensorType::get(shape.getValue(), element_type); + return UnrankedTensorType::get(element_type); +} + +// Creates a ShapedTypeComponent from a shape handle and element type. +ShapedTypeComponents CreateShapedTypeComponents(InferenceContext& context, + const ShapeHandle& sh, + Type element_type) { + auto shape = GetShapeFromHandle(context, sh); + if (shape.hasValue()) + return ShapedTypeComponents(shape.getValue(), element_type); + return ShapedTypeComponents(element_type); +} + +// Runs TensorFlow shape inference associated to the op type registered in the +// TensorFlow op registry based on Graph version, operands, and attributes. +// Invoking this shape function will invoke conversions of parameters to the +// TensorFlow Graph equivalent data structures and back to MLIR equivalent data +// structures. This does not use a natively implemented shape inference in MLIR, +// and instead is temporary until shape functions are reimplemented/migrated to +// being in MLIR instead of the TensorFlow op registry. 
+LogicalResult InferReturnTypeComponentsFallback( + MLIRContext* context, StringRef op_name, int64_t graph_version, + Optional location, ValueRange operands, + const NamedAttrList& attributes, OperandAsConstantFn operand_as_constant_fn, + OpResultAsShapeFn op_result_as_shape_fn, + ResultElementTypeFn result_element_type_fn, + SmallVectorImpl& inferred_return_shapes) { + assert(op_name.startswith(TensorFlowDialect::getDialectNamespace())); + // Drop the `tf.` prefix to query TF registry. + std::string op_type = + op_name.drop_front(TensorFlowDialect::getDialectNamespace().size() + 1) + .str(); + + // Get information from the registry and check if we have a shape function for + // this op. + const tensorflow::OpRegistrationData* op_reg_data = + tensorflow::OpRegistry::Global()->LookUp(op_type); + if (!op_reg_data) { + LLVM_DEBUG(llvm::dbgs() << "Skipping inference for unregistered op '" + << op_name << "'.\n"); + return emitOptionalError(location, "op is unregistered"); + } + if (!op_reg_data->shape_inference_fn) { + LLVM_DEBUG(llvm::dbgs() + << "Skipping inference for op without shape function '" + << op_name << "'.\n"); + return emitOptionalError(location, "missing shape function"); + } + + // Convert the operation attributes to be able to use the InferenceContext + // and the TensorFlow shape function. + tensorflow::AttrValueMap converted_attributes; + NamedAttrList attributes_to_convert; + // Filter out unregistered attributes. + for (const auto& attr_def : op_reg_data->op_def.attr()) + if (auto registered_attr = attributes.get(attr_def.name())) + attributes_to_convert.set(attr_def.name(), registered_attr); + + auto attrs_status = tensorflow::ConvertAttributes( + attributes_to_convert, /*attrs_to_ignore=*/{}, &converted_attributes); + if (!attrs_status.ok()) { + LLVM_DEBUG(llvm::dbgs() << "Error creating attribute map for '" << op_name + << "': " << attrs_status.error_message() << "\n"); + return emitOptionalError( + location, + "failed to convert attributes to proto map"); + } + + // Collect an array with input values for constant operands and input shapes + // for all the operands. + std::vector input_tensors(operands.size()); + std::vector input_shapes(operands.size()); + std::vector tensors(operands.size()); + std::vector>>> + handle_shapes_and_types(operands.size()); + for (auto it : llvm::enumerate(operands)) { + Value operand = it.value(); + size_t index = it.index(); + + // If the operand is constant, then convert it to Tensor. + if (auto attr = operand_as_constant_fn(operand)) { + tensorflow::Tensor* input_tensor = &tensors[index]; + auto status = + tensorflow::ConvertToTensor(attr.cast(), input_tensor); + if (status.ok()) { + input_tensors[index] = input_tensor; + } else { + LLVM_DEBUG(llvm::dbgs() << "Error converting input " << index + << " of op '" << op_name << "' to Tensor: " + << status.error_message() << "\n"); + } + } + + Type operand_type = operand.getType(); + if (auto shape = GetShapeFromMlirType(operand_type)) { + input_shapes[index] = *shape; + } + // Collect the handle shapes and types for a resource/variant. + handle_shapes_and_types[index] = GetSubtypes(operand_type); + } + + // Perform the shape inference using an InferenceContext with the input + // shapes. This object is abstracting the information that the ShapeInference + // function operates on. 
+ InferenceContext c(graph_version, + tensorflow::AttrSlice(&converted_attributes), + op_reg_data->op_def, input_shapes, input_tensors, + /*input_tensors_as_shapes=*/{}, handle_shapes_and_types); + auto status = c.Run(op_reg_data->shape_inference_fn); + if (!status.ok()) + return EmitErrorFromShapeFunction(location, op_name, + status.error_message()); + + // Determine if, during shape computation, the shape functions attempted to + // query an input operand as shape where the input was not known/constant. + bool requires_inputs = + any_of(llvm::seq(0, c.num_inputs()), [&](int input) { + return c.requested_input_tensor_as_partial_shape(input) && + !input_tensors[input]; + }); + if (requires_inputs) { + LLVM_DEBUG(llvm::dbgs() << "\trequired input\n"); + std::vector input_tensors_as_shapes; + for (int input : llvm::seq(0, c.num_inputs())) { + if (c.requested_input_tensor_as_partial_shape(input) && + !input_tensors[input]) { + LLVM_DEBUG(llvm::dbgs() << "Requesting " << input << " as shape\n"); + auto op_result = operands[input].dyn_cast(); + if (!op_result) continue; + // Resize on first valid shape computed. + input_tensors_as_shapes.resize(c.num_inputs()); + auto handle = op_result_as_shape_fn(c, op_result); + LLVM_DEBUG(llvm::dbgs() << "Requested " << input << " as shape " + << (handle.Handle() ? "found" : "not found")); + if (handle.Handle()) input_tensors_as_shapes[input] = handle; + } + } + + // Attempt to compute the unknown operands as shapes. + // Note: in the case where no partial outputs could be computed, this + // would be empty. + if (!input_tensors_as_shapes.empty()) { + c.set_input_tensors_as_shapes(input_tensors_as_shapes); + auto status = c.Run(op_reg_data->shape_inference_fn); + if (!status.ok()) + return EmitErrorFromShapeFunction(location, op_name, + status.error_message()); + } + } + + // Update the shape for each of the operation result if the InferenceContext + // has more precise shapes recorded. + for (int output : llvm::seq(0, c.num_outputs())) { + ShapeHandle shape_handle = c.output(output); + LLVM_DEBUG(llvm::dbgs() << "Inferred output " << output << " : " + << c.DebugString(shape_handle) << "\n"); + + Type new_element_type = result_element_type_fn(output); + // Populate the handle shapes for a resource/variant. 
+ if (new_element_type && + new_element_type.isa()) { + auto handle_shapes_types = c.output_handle_shapes_and_types(output); + if (handle_shapes_types) { + SmallVector subtypes; + Builder b(context); + for (const auto& shape_n_type : *handle_shapes_types) { + Type element_type; + auto status = + tensorflow::ConvertDataType(shape_n_type.dtype, b, &element_type); + assert(status.ok() && "Unknown element type"); + subtypes.push_back( + CreateTensorType(c, shape_n_type.shape, element_type)); + } + if (new_element_type.isa()) { + new_element_type = TF::ResourceType::get(subtypes, context); + } else { + new_element_type = TF::VariantType::get(subtypes, context); + } + } + } + inferred_return_shapes.push_back( + CreateShapedTypeComponents(c, shape_handle, new_element_type)); + } + + return success(); +} + +} // namespace + +LogicalResult InferReturnTypeComponentsForTFOp( + Optional location, Operation* op, int64_t graph_version, + OperandAsConstantFn operand_as_constant_fn, + OpResultAsShapeFn op_result_as_shape_fn, + ResultElementTypeFn result_element_type_fn, + SmallVectorImpl& inferred_return_shapes) { + auto attributes = GetAllAttributesFromOperation(op); + return InferReturnTypeComponentsFallback( + op->getContext(), op->getName().getStringRef(), graph_version, location, + op->getOperands(), attributes, operand_as_constant_fn, + op_result_as_shape_fn, result_element_type_fn, inferred_return_shapes); +} + +LogicalResult InferReturnTypeComponentsForTFOp( + Optional location, Operation* op, int64_t graph_version, + SmallVectorImpl& inferred_return_shapes) { + if (auto type_op = dyn_cast(op)) { + auto attributes = GetAllAttributesFromOperation(op); + SmallVector inferred_return_types; + auto result = type_op.inferReturnTypes( + op->getContext(), location, op->getOperands(), + DictionaryAttr::get(attributes, op->getContext()), op->getRegions(), + inferred_return_types); + if (failed(result)) return failure(); + + inferred_return_shapes.resize(inferred_return_types.size()); + for (auto inferred_return_type : llvm::enumerate(inferred_return_types)) { + if (auto shaped_type = + inferred_return_type.value().dyn_cast()) { + if (shaped_type.hasRank()) { + inferred_return_shapes[inferred_return_type.index()] = + ShapedTypeComponents(shaped_type.getShape(), + shaped_type.getElementType()); + } else { + inferred_return_shapes[inferred_return_type.index()] = + ShapedTypeComponents(shaped_type.getElementType()); + } + } + } + + return success(); + } + + if (auto shape_type_op = dyn_cast(op)) { + auto attributes = GetAllAttributesFromOperation(op); + return shape_type_op.inferReturnTypeComponents( + op->getContext(), location, op->getOperands(), + DictionaryAttr::get(attributes, op->getContext()), op->getRegions(), + inferred_return_shapes); + } + + auto operand_as_constant_fn = [](Value operand) -> Attribute { + Attribute attr; + if (matchPattern(operand, m_Constant(&attr))) return attr; + return nullptr; + }; + + auto op_result_as_shape_fn = [](InferenceContext& ic, + OpResult op_result) -> ShapeHandle { + auto rt = op_result.getType().dyn_cast(); + if (!rt || rt.getRank() != 1 || !rt.hasStaticShape()) return {}; + + std::vector dims(rt.getDimSize(0), ic.UnknownDim()); + Attribute attr; + if (matchPattern(op_result, m_Constant(&attr))) { + auto elements = attr.dyn_cast(); + if (elements) + for (auto element : llvm::enumerate(elements.getIntValues())) + dims[element.index()] = ic.MakeDim(element.value().getSExtValue()); + } + return ic.MakeShape(dims); + }; + + auto result_element_type_fn = [](int) -> Type 
{ return nullptr; }; + + return InferReturnTypeComponentsForTFOp( + location, op, graph_version, operand_as_constant_fn, + op_result_as_shape_fn, result_element_type_fn, inferred_return_shapes); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h new file mode 100644 index 00000000000..eda2bc49514 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SHAPE_INFERENCE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SHAPE_INFERENCE_UTILS_H_ + +#include + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/core/framework/shape_inference.h" + +namespace mlir { +namespace TF { + +// Function that takes in a value and extracts a constant from it, if available. +// If the value cannot be resolved as a constant, a nullptr will be returned. +// Certain shape functions require constant values as arguments. +using OperandAsConstantFn = llvm::function_ref; + +// Function that takes in an operation result and computes a shape (can be +// partial) value. Certain shape functions require shape values as arguments. +using OpResultAsShapeFn = + llvm::function_ref; + +// Function that takes a result index and returns the element type. Element +// types are necessary for handle types (resource, variant). +using ResultElementTypeFn = llvm::function_ref; + +// Runs TensorFlow shape inference associated to the op type registered in the +// TensorFlow op registry based on the Graph version, operands, and attributes. +// Invoking this shape function will create conversions of parameters to the +// TensorFlow Graph equivalent data structures and back to MLIR equivalent data +// structures. This does not use a natively implemented shape inference in MLIR, +// and instead is temporary until shape functions are reimplemented/migrated to +// being in MLIR instead of the TensorFlow op registry. +LogicalResult InferReturnTypeComponentsForTFOp( + Optional location, Operation* op, int64_t graph_version, + OperandAsConstantFn operand_as_constant_fn, + OpResultAsShapeFn op_result_as_shape_fn, + ResultElementTypeFn result_element_type_fn, + SmallVectorImpl& inferred_return_shapes); + +// Runs TensorFlow shape inference for an operation for a given Graph version. 
+// If an operation implements the `InferTypeOpInterface` or +// `InferShapedTypeOpInterface` interfaces, those are used instead but with +// derived attributes populated. Otherwise the above function is used but with +// default `operand_as_constant_fn` and `op_result_as_shape_fn` that only +// extracts a value if the operands are constant (no partial evaluation, and an +// empty `result_element_type_fn`. Element types with subtypes (DT_RESOURCE, +// DT_VARIANT) are not supported. +LogicalResult InferReturnTypeComponentsForTFOp( + Optional location, Operation* op, int64_t graph_version, + SmallVectorImpl& inferred_return_shapes); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SHAPE_INFERENCE_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc new file mode 100644 index 00000000000..bcc3fe62f99 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc @@ -0,0 +1,334 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "tensorflow/compiler/mlir/utils/string_container_utils.h" +#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h" +#include "tensorflow/compiler/tf2xla/xla_argument.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" 
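
Taken together, the utilities introduced in this patch can be exercised roughly as follows: deserialize a textual module produced by SerializeMlirModule, then run the fallback TensorFlow shape inference on each tf.* op. This is an illustrative sketch only; the function name, the dialect-namespace check, and the graph version constant are assumptions, not code from this change:

#include <cstdint>

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h"
#include "tensorflow/core/platform/status.h"

// Parse a serialized MLIR module and infer result shapes for every TF op.
void InferShapesForSerializedModule(llvm::StringRef serialized_module) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::TF::TensorFlowDialect>();

  mlir::OwningModuleRef module;
  tensorflow::Status status =
      tensorflow::DeserializeMlirModule(serialized_module, &context, &module);
  if (!status.ok()) return;

  constexpr int64_t kGraphVersion = 440;  // Arbitrary graph producer version.
  module->walk([&](mlir::Operation* op) {
    // Only query the TensorFlow op registry for ops in the tf dialect.
    if (!op->getDialect() || op->getDialect()->getNamespace() != "tf") return;
    llvm::SmallVector<mlir::ShapedTypeComponents, 4> inferred;
    if (mlir::succeeded(mlir::TF::InferReturnTypeComponentsForTFOp(
            op->getLoc(), op, kGraphVersion, inferred))) {
      // `inferred` now holds one ShapedTypeComponents entry per op result.
    }
  });
}
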
+#include "tensorflow/core/platform/status.h" + +// NOLINTNEXTLINE +llvm::cl::opt input_types( + "tf-xla-input-types", + llvm::cl::desc("XLA input argument types (kinds), separated by ','. " + "Supported types include ['parameter', 'resource']. If " + "empty, all arguments are assumed to be parameters."), + llvm::cl::init("")); + +namespace tensorflow { + +namespace { + +mlir::LogicalResult PrintHloModuleText( + const XlaCompilationResult& compilation_result, llvm::raw_ostream& output) { + const xla::HloModuleConfig module_config( + compilation_result.computation->GetProgramShape().ValueOrDie()); + auto status_or_hlo_module = xla::HloModule::CreateFromProto( + compilation_result.computation->proto(), module_config); + if (!status_or_hlo_module.ok()) { + LOG(ERROR) << "Conversion to HLO module failed: " + << status_or_hlo_module.status().ToString(); + return mlir::failure(); + } + + xla::HloModule* hlo_module = status_or_hlo_module.ValueOrDie().get(); + + output << hlo_module->ToString(); + + if (!compilation_result.input_mapping.empty()) + output << "// InputMapping {" + << absl::StrJoin(compilation_result.input_mapping, ", ") << "}\n"; + + for (const auto& xla_input_shape : compilation_result.xla_input_shapes) + output << "// XlaInputShape " << xla_input_shape.ToString() << '\n'; + + output << "// XlaOutputShape " + << compilation_result.xla_output_shape.ToString() << '\n'; + + for (const auto& xla_output_description : compilation_result.outputs) { + output << "// XlaOutputDescription type=" + << DataTypeString(xla_output_description.type) << " shape=(" + << absl::StrJoin(xla_output_description.shape.dim_sizes(), ", ") + << ')'; + if (xla_output_description.input_index >= 0) + output << " input_index=" << xla_output_description.input_index; + if (xla_output_description.is_constant) output << " constant"; + if (xla_output_description.is_tensor_list) output << " tensor_list"; + output << '\n'; + } + + for (const auto& resource_update : compilation_result.resource_updates) { + output << "// ResourceUpdate input_index=" << resource_update.input_index + << " type=" << DataTypeString(resource_update.type) << " shape=(" + << absl::StrJoin(resource_update.shape.dim_sizes(), " ") << ')'; + if (resource_update.modified) output << " modified"; + output << '\n'; + } + + return mlir::success(); +} + +Status ParseArgumentShapes( + absl::string_view input_shapes_str, + llvm::SmallVectorImpl& arg_shapes) { + arg_shapes.clear(); + std::vector> input_shapes_vector; + TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes_str, input_shapes_vector)); + arg_shapes.resize(input_shapes_vector.size()); + for (const auto& shape : llvm::enumerate(input_shapes_vector)) + TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape( + shape.value(), &arg_shapes[shape.index()].shape)); + + return Status::OK(); +} + +Status ParseDataTypes(absl::string_view data_types_str, + llvm::SmallVectorImpl& data_types) { + data_types.clear(); + std::vector input_dtypes_vector; + TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types_str, input_dtypes_vector)); + data_types.resize(input_dtypes_vector.size(), DT_INVALID); + for (auto data_type : llvm::enumerate(input_dtypes_vector)) { + if (!DataType_Parse(data_type.value(), &data_types[data_type.index()])) + return errors::InvalidArgument("Invalid dtype at index ", + data_type.index(), ": ", + data_type.value()); + const auto& resolved_dtype = data_types[data_type.index()]; + if (resolved_dtype == DT_INVALID || resolved_dtype == DT_STRING || + resolved_dtype == DT_RESOURCE || resolved_dtype == DT_VARIANT 
|| + IsRefType(resolved_dtype)) + return errors::InvalidArgument("Unsupported dtype at index ", + data_type.index(), ": ", + data_type.value()); + } + + return Status::OK(); +} + +Status ParseArgumentKinds( + absl::string_view input_types_str, + llvm::SmallVectorImpl& argument_kinds) { + argument_kinds.clear(); + if (input_types_str.empty()) return Status::OK(); + + std::vector argument_kind_strs = + absl::StrSplit(input_types_str, ','); + argument_kinds.reserve(argument_kind_strs.size()); + for (const auto& argument_kind_str : llvm::enumerate(argument_kind_strs)) { + const auto& value = argument_kind_str.value(); + if (value == "parameter") { + argument_kinds.push_back(XlaArgument::Kind::kParameter); + } else if (value == "resource") { + argument_kinds.push_back(XlaArgument::Kind::kResource); + } else { + return errors::InvalidArgument( + "Unsupported TF/XLA argument kind at index ", + argument_kind_str.index(), ": ", value); + } + } + + return Status::OK(); +} + +Status ParseXlaArguments(absl::string_view input_shapes_str, + absl::string_view input_dtypes_str, + absl::string_view arg_kinds_str, + llvm::SmallVectorImpl& xla_arguments) { + xla_arguments.clear(); + std::vector> input_shapes_vector; + TF_RETURN_IF_ERROR( + tensorflow::ParseNodeShapes(input_shapes_str, input_shapes_vector)); + llvm::SmallVector dtypes_vector; + TF_RETURN_IF_ERROR(ParseDataTypes(input_dtypes_str, dtypes_vector)); + llvm::SmallVector arg_kinds_vector; + TF_RETURN_IF_ERROR(ParseArgumentKinds(arg_kinds_str, arg_kinds_vector)); + + if (input_shapes_vector.empty()) + input_shapes_vector.resize(dtypes_vector.size()); + + if (arg_kinds_vector.empty()) + arg_kinds_vector.resize(input_shapes_vector.size(), + XlaArgument::Kind::kParameter); + + if (input_shapes_vector.size() != dtypes_vector.size() || + input_shapes_vector.size() != arg_kinds_vector.size()) + return errors::InvalidArgument( + "Input shapes, dtypes, and types/kinds must be of the same " + "length, but got ", + input_shapes_vector.size(), ", ", dtypes_vector.size(), ", and ", + arg_kinds_vector.size(), " respectively"); + + xla_arguments.resize(input_shapes_vector.size()); + for (const auto& arg_components : + llvm::zip(xla_arguments, input_shapes_vector, dtypes_vector, + arg_kinds_vector)) { + XlaArgument& arg = std::get<0>(arg_components); + TensorShape shape; + TF_RETURN_IF_ERROR( + TensorShapeUtils::MakeShape(std::get<1>(arg_components), &shape)); + arg.shape = std::move(shape); + arg.type = std::get<2>(arg_components); + arg.kind = std::get<3>(arg_components); + } + + return Status::OK(); +} + +} // anonymous namespace + +static mlir::LogicalResult MlirTfToHloTextTranslateFunction( + mlir::ModuleOp module_op, llvm::raw_ostream& output) { + if (!module_op) return mlir::failure(); + + llvm::SmallVector arg_shapes; + auto args_status = + ParseArgumentShapes(mlir::StringRefToView(input_shapes), arg_shapes); + if (!args_status.ok()) { + LOG(ERROR) << args_status.ToString(); + return mlir::failure(); + } + + XlaCompilationResult compilation_result; + auto compilation_status = CompileMlirToXlaHlo( + module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg, + emit_return_tuple, IdentityShapeRepresentationFn(), &compilation_result, + /*custom_legalization_passes=*/{}); + if (!compilation_status.ok()) { + LOG(ERROR) << "TF/XLA compilation failed: " + << compilation_status.ToString(); + return mlir::failure(); + } + + return PrintHloModuleText(compilation_result, output); +} + +static mlir::LogicalResult MlirTfGraphToHloTextTranslateFunction( + 
mlir::ModuleOp module_op, llvm::raw_ostream& output) { + if (!module_op) return mlir::failure(); + + llvm::SmallVector xla_arguments; + auto args_status = ParseXlaArguments( + mlir::StringRefToView(input_shapes), mlir::StringRefToView(input_dtypes), + mlir::StringRefToView(input_types), xla_arguments); + if (!args_status.ok()) { + LOG(ERROR) << args_status.ToString(); + return mlir::failure(); + } + + XlaCompilationResult compilation_result; + auto compilation_status = CompileGraphToXlaHlo( + module_op, xla_arguments, /*device_type=*/"XLA_CPU_JIT", + emit_use_tuple_arg, emit_return_tuple, IdentityShapeRepresentationFn(), + &compilation_result, /*custom_legalization_passes=*/{}); + if (!compilation_status.ok()) { + LOG(ERROR) << "TF/XLA compilation failed: " + << compilation_status.ToString(); + return mlir::failure(); + } + + return PrintHloModuleText(compilation_result, output); +} + +static void RegisterMlirInputDialects(mlir::DialectRegistry& registry) { + registry.insert(); +} + +static void RegisterGraphInputDialects(mlir::DialectRegistry& registry) { + RegisterMlirInputDialects(registry); + registry.insert(); +} + +static mlir::OwningModuleRef SerializedMlirStringAttrToMlirModuleTranslate( + llvm::StringRef input, mlir::MLIRContext* context) { + mlir::Attribute attr = mlir::parseAttribute(input, context); + if (!attr || !attr.isa()) { + LOG(ERROR) << "Input is not parsable as a MLIR StringAttr."; + return nullptr; + } + auto str_attr = attr.cast(); + + RegisterMlirInputDialects(context->getDialectRegistry()); + mlir::OwningModuleRef module_ref; + auto status = + DeserializeMlirModule(str_attr.getValue().str(), context, &module_ref); + if (!status.ok()) { + LOG(ERROR) << status.ToString(); + return nullptr; + } + + return module_ref; +} + +static mlir::LogicalResult MlirModuleToSerializedMlirStringAttrTranslate( + mlir::ModuleOp module_op, llvm::raw_ostream& output) { + output << "\""; + std::string serialized_module = SerializeMlirModule(module_op); + llvm::printEscapedString(serialized_module, output); + output << "\""; + return mlir::success(); +} + +} // namespace tensorflow + +static mlir::TranslateFromMLIRRegistration MlirTfToHloTextTranslate( + "mlir-tf-to-hlo-text", tensorflow::MlirTfToHloTextTranslateFunction, + tensorflow::RegisterMlirInputDialects); + +static mlir::TranslateFromMLIRRegistration MlirTfGraphToHloTextTranslate( + "mlir-tf-graph-to-hlo-text", + tensorflow::MlirTfGraphToHloTextTranslateFunction, + tensorflow::RegisterGraphInputDialects); + +static mlir::TranslateToMLIRRegistration SerializedMlirStringAttrToMlirModule( + "mlir-tf-str-attr-to-mlir", + tensorflow::SerializedMlirStringAttrToMlirModuleTranslate); + +static mlir::TranslateFromMLIRRegistration MlirModuleToSerializedMlirStringAttr( + "mlir-tf-mlir-to-str-attr", + tensorflow::MlirModuleToSerializedMlirStringAttrTranslate, + tensorflow::RegisterMlirInputDialects); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index 843d491c330..3516e3a65d9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -374,9 +374,8 @@ GetGeneralTPUExecutionDeviceAssignment( return (x + bound_x * (y + bound_y * z)) * bound_core + core; }; - std::vector used_device_ids( - location_to_id(bound_x - 1, bound_y - 1, bound_z - 1, bound_core - 1), - false); + std::vector used_device_ids(bound_x * bound_y * bound_z * bound_core, + 
false); TPUDevicesAndHosts devices_and_hosts( num_replicas, llvm::SmallVector( num_cores_per_replica, TPUDeviceAndHost())); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index b23fbe7d73c..19eb5b2c476 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -625,8 +625,8 @@ TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) { } TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) { - mlir::registerDialect(); mlir::MLIRContext context; + context.loadDialect(); mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); mlir::OpBuilder builder(module_ref->getBodyRegion()); @@ -641,8 +641,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) { } TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailModelParallelism) { - mlir::registerDialect(); mlir::MLIRContext context; + context.loadDialect(); mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); mlir::OpBuilder builder(module_ref->getBodyRegion()); @@ -662,8 +662,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailModelParallelism) { } TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) { - mlir::registerDialect(); mlir::MLIRContext context; + context.loadDialect(); mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); mlir::OpBuilder builder(module_ref->getBodyRegion()); @@ -682,8 +682,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) { } TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) { - mlir::registerDialect(); mlir::MLIRContext context; + context.loadDialect(); mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); mlir::OpBuilder builder(module_ref->getBodyRegion()); @@ -702,8 +702,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) { } TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) { - mlir::registerDialect(); mlir::MLIRContext context; + context.loadDialect(); mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); mlir::OpBuilder builder(module_ref->getBodyRegion()); @@ -725,8 +725,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) { } TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) { - mlir::registerDialect(); mlir::MLIRContext context; + context.loadDialect(); mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); mlir::OpBuilder builder(module_ref->getBodyRegion()); @@ -750,8 +750,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) { } TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) { - mlir::registerDialect(); mlir::MLIRContext context; + context.loadDialect(); mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); mlir::OpBuilder builder(module_ref->getBodyRegion()); @@ -777,8 +777,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) { } TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceNotReplicated) { - mlir::registerDialect(); mlir::MLIRContext context; + context.loadDialect(); mlir::OwningModuleRef module_ref = 
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); mlir::OpBuilder builder(module_ref->getBodyRegion()); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/visitor_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/visitor_util.cc new file mode 100644 index 00000000000..0647d42f315 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/visitor_util.cc @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h" + +#include "mlir/IR/Operation.h" // from @llvm-project + +namespace tensorflow { + +WalkStage::WalkStage(mlir::Operation *op) + : num_regions_(op->getNumRegions()), next_region_(0) {} + +namespace detail { + +/// Walk all of the operations nested under and including the given operations. +void WalkOperations(mlir::Operation *op, VoidCallback callback) { + WalkStage stage(op); + + for (auto ®ion : op->getRegions()) { + // Invoke callback on the parent op before visiting each child region. + callback(op, stage); + stage.Advance(); + + for (auto &block : region) + // Early increment here in the case where the operation is erased. + for (auto &nestedOp : llvm::make_early_inc_range(block)) + WalkOperations(&nestedOp, callback); + } + + // Invoke callback after all regions have been visited. + callback(op, stage); +} + +/// Walk all of the operations nested under and including the given operations. +/// This methods walks operations until an interrupt signal is received. +mlir::WalkResult WalkOperations(mlir::Operation *op, + InterruptCallback callback) { + WalkStage stage(op); + + for (auto ®ion : op->getRegions()) { + // Invoke callback on the parent op before visiting each child region. + if (callback(op, stage).wasInterrupted()) + return mlir::WalkResult::interrupt(); + + stage.Advance(); + + for (auto &block : region) { + // Early increment here in the case where the operation is erased. + for (auto &nestedOp : llvm::make_early_inc_range(block)) + if (WalkOperations(&nestedOp, callback).wasInterrupted()) + return mlir::WalkResult::interrupt(); + } + } + return callback(op, stage); +} + +} // namespace detail +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h b/tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h new file mode 100644 index 00000000000..31c1f4b62e6 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h @@ -0,0 +1,168 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_ + +#include <type_traits> + +#include "mlir/IR/Visitors.h" // from @llvm-project + +// This file defines generic (pre/in/post)-order MLIR IR visitors/walkers. The +// walk() utility that MLIR core provides traverses operations in a block/ +// blocks in a region in the program order, and these walkers do the same. When +// operations have regions attached to them, the core MLIR walkers visit the +// regions attached to an Op first, and then visit the op. So within the context +// of a single Op, the traversal is post-order (considering the Op as the parent +// node and regions as the children). For certain use cases, it may be more +// efficient/desirable to visit the parent Op before visiting the attached +// regions. As an example, if the attached regions have region arguments that +// are related to the operation inputs (tf.WhileRegion is an example), then we +// may want to propagate some information from the Op inputs to the region +// inputs and then visit the regions to continue propagating that information +// within the regions. With just post-order traversal, to achieve the same we +// may need to schedule another walk to make sure child regions get visited. +// A pre-order walk (within the context of a single operation) will avoid that. +// Similarly, for certain operations, we may want to visit the Op both before +// and after all regions have been visited (say to propagate information from +// inputs -> region arguments and then from region results -> outputs). 
+// In general, since the data flow between an operation and its regions is +// opaque in MLIR, we may need to visit the operation in-between regions as well +// if say region0 is transferring control back to the Op and from then to +// region1. So a more general walker that supports pre/in/post-order walk is +// desirable. To support this, the generic walkers defined below will invoke +// the walk callback on the parent Op at each stage of the child region walk, +// i.e., before visiting any region, in between regions, and after visiting all +// regions. To indicate the current walk stage, the callback will also get a +// `WalkStage` parameter. The callback can inspect the current walk stage and +// decide to take appropriate actions (including not doing anything). With this +// the walker below can support pre/in/post-order walks as well as a combined +// (pre+in+post)-order walk. + +namespace tensorflow { + +// A class to indicate the current walk stage.
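A minimal sketch of the intended call pattern for the walker interface in this header, placed here before the WalkStage class definition. Only GenericWalk and WalkStage (declared in this file) are assumed; the function name and the callback body are illustrative:

#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"

// Combined (pre+in+post)-order visit of every op nested under `root`.
void VisitAllStages(mlir::Operation *root) {
  tensorflow::GenericWalk(
      root, [](mlir::Operation *op, const tensorflow::WalkStage &stage) {
        if (stage.IsBeforeAllRegions()) {
          // Pre-order work, e.g. seed per-region state from the op's operands.
        } else if (stage.IsAfterAllRegions()) {
          // Post-order work, e.g. fold information produced by the regions.
        } else {
          // In-between-regions work: the previous region has been visited,
          // region stage.GetNextRegion() has not.
        }
      });
}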
+class WalkStage { + public: + explicit WalkStage(mlir::Operation *op); + + bool IsBeforeAllRegions() const { return next_region_ == 0; } + bool IsBeforeRegion(int region) const { return next_region_ == region; } + bool IsAfterRegion(int region) const { return next_region_ == region + 1; } + bool IsAfterAllRegions() const { return next_region_ == num_regions_; } + void Advance() { next_region_++; } + int GetNextRegion() const { return next_region_; } + + private: + const int num_regions_; + int next_region_; +}; + +namespace detail { +// This is similar to MLIR version, but works with multiple argument functions. +// Helper templates to deduce the first argument of a callback parameter. +template +Arg first_argument_type(Ret (*)(Arg, Rest...)); +template +Arg first_argument_type(Ret (F::*)(Arg, Rest...)); +template +Arg first_argument_type(Ret (F::*)(Arg, Rest...) const); +template +decltype(first_argument_type(&F::operator())) first_argument_type(F); + +/// Type definition of the first argument to the given callable 'T'. +template +using first_argument = decltype(first_argument_type(std::declval())); + +using VoidCallback = + llvm::function_ref; +using InterruptCallback = + llvm::function_ref; + +// Walk all of the operations nested under and including the given operation. +void WalkOperations(mlir::Operation *op, VoidCallback callback); + +// Walk all of the operations nested under and including the given operation. +// This methods walks operations until an interrupt result is returned by the +// callback. +mlir::WalkResult WalkOperations(mlir::Operation *op, + InterruptCallback callback); + +} // namespace detail + +// Walk all of the operations nested under and including the given operation. +// This method is selected for stage-aware callbacks that operate on Operation*. +// +// Example: +// tensorflow::walk(op, [](Operation *op, const WalkStage &stage) { ... }); +template , + typename RetT = decltype(std::declval()( + std::declval(), std::declval()))> +typename std::enable_if::value, + RetT>::type +GenericWalk(mlir::Operation *op, FuncTy &&callback) { + return detail::WalkOperations( + op, llvm::function_ref(callback)); +} + +// Walk all of the operations of type 'ArgT' nested under and including the +// given operation. This method is selected for void returning callbacks that +// operate on a specific derived operation type. +// +// Example: +// tensorflow::walk(op, [](ReturnOp op, const WalkStage &stage) { ... }); +template , + typename RetT = decltype(std::declval()( + std::declval(), std::declval()))> +typename std::enable_if::value && + std::is_same::value, + RetT>::type +GenericWalk(mlir::Operation *op, FuncTy &&callback) { + auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) { + if (auto derivedOp = llvm::dyn_cast(op)) callback(derivedOp, stage); + }; + return detail::WalkOperations(op, + static_cast(wrapperFn)); +} + +// Walk all of the operations of type 'ArgT' nested under and including the +// given operation. This method is selected for WalkReturn returning +// interruptible callbacks that operate on a specific derived operation type. 
+// +// Example: +// tensorflow::walk(op, [](ReturnOp op, const WalkStage &stage) { +// if (some_invariant) +// return WalkResult::interrupt(); +// return WalkResult::advance(); +// }); +template , + typename RetT = decltype(std::declval()( + std::declval(), std::declval()))> +typename std::enable_if::value && + std::is_same::value, + RetT>::type +GenericWalk(mlir::Operation *op, FuncTy &&callback) { + auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) { + if (auto derivedOp = llvm::dyn_cast(op)) + return callback(derivedOp, stage); + return mlir::WalkResult::advance(); + }; + return detail::WalkOperations( + op, static_cast(wrapperFn)); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index 1416ac038d6..e48b14a6bc3 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -13,81 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/InitLLVM.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/ToolOutputFile.h" -#include "mlir/IR/AsmState.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project #include "mlir/Support/MlirOptMain.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" #include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" - -// NOLINTNEXTLINE -static llvm::cl::opt input_filename(llvm::cl::Positional, - llvm::cl::desc(""), - llvm::cl::init("-")); - -// NOLINTNEXTLINE -static llvm::cl::opt output_filename( - "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), - llvm::cl::init("-")); - -// NOLINTNEXTLINE -static llvm::cl::opt split_input_file( - "split-input-file", - llvm::cl::desc("Split the input file into pieces and process each " - "chunk independently"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt verify_diagnostics( - "verify-diagnostics", - llvm::cl::desc("Check that emitted diagnostics match " - "expected-* lines on the corresponding line"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt verify_passes( - "verify-each", - llvm::cl::desc("Run the verifier after each transformation pass"), - llvm::cl::init(true)); - -// NOLINTNEXTLINE -static llvm::cl::opt allowUnregisteredDialects( - "allow-unregistered-dialect", - llvm::cl::desc("Allow operation with no registered dialects"), - llvm::cl::init(false)); int main(int argc, char **argv) { tensorflow::InitMlir y(&argc, &argv); - // Register various MLIR command line options. 
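The replacement main() in this hunk (completed just below) drops the hand-rolled flag and file handling in favor of the registry-taking MlirOptMain entry point. A standalone sketch of the same pattern, with a hypothetical dialect list spelled out in the registry.insert<>() calls; only upstream MLIR and TensorFlow headers already used by this change are assumed:

#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/InitAllPasses.h"               // from @llvm-project
#include "mlir/Support/MlirOptMain.h"         // from @llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

int main(int argc, char **argv) {
  tensorflow::InitMlir y(&argc, &argv);

  mlir::registerAllPasses();

  // Every dialect the tool may parse is declared up front; MlirOptMain then
  // loads them into the MLIRContext it creates, replacing the old global
  // registration and the hand-written command-line handling.
  mlir::DialectRegistry registry;
  registry.insert<mlir::StandardOpsDialect>();
  registry.insert<mlir::TF::TensorFlowDialect>();
  return failed(
      mlir::MlirOptMain(argc, argv, "Example MLIR opt driver\n", registry));
}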
- mlir::registerAsmPrinterCLOptions(); - mlir::registerMLIRContextCLOptions(); - mlir::registerPassManagerCLOptions(); + mlir::registerAllPasses(); - // Parse pass names in main to ensure static initialization completed. - mlir::PassPipelineCLParser pass_pipeline("", "Compiler passes to run"); - - llvm::cl::ParseCommandLineOptions(argc, argv, - "TF MLIR modular optimizer driver\n"); - - // Set up the input file. - std::string error_message; - auto file = mlir::openInputFile(input_filename, &error_message); - QCHECK(file) << error_message; - - auto output = mlir::openOutputFile(output_filename, &error_message); - QCHECK(output) << error_message; - - if (failed(mlir::MlirOptMain(output->os(), std::move(file), pass_pipeline, - split_input_file, verify_diagnostics, - verify_passes, allowUnregisteredDialects))) - return 1; - output->keep(); - return 0; + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + mlir::mhlo::registerAllMhloDialects(registry); + registry.insert(); + registry.insert(); + registry.insert(); + return failed( + mlir::MlirOptMain(argc, argv, "TensorFlow pass driver\n", registry)); } diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index 8cfdfd01120..3ea92a70ec7 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -111,7 +111,6 @@ int main(int argc, char** argv) { if (import_saved_model_object_graph) { mlir::MLIRContext context; - auto module_or = tensorflow::SavedModelObjectGraphToMlirImport( input_filename, tags, exported_names, &context); if (!module_or.status().ok()) return 1; @@ -119,9 +118,8 @@ int main(int argc, char** argv) { module_or.ConsumeValueOrDie()->print(output->os()); } else if (import_saved_model_signature_defs) { mlir::MLIRContext context; - auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport( - input_filename, tags, exported_names, &context); + input_filename, tags, exported_names, &context, upgrade_legacy); if (!module_or.status().ok()) return 1; module_or.ConsumeValueOrDie()->print(output->os()); diff --git a/tensorflow/compiler/mlir/tfjs/BUILD b/tensorflow/compiler/mlir/tfjs/BUILD index 7d3091f921f..b1bf20e4e48 100644 --- a/tensorflow/compiler/mlir/tfjs/BUILD +++ b/tensorflow/compiler/mlir/tfjs/BUILD @@ -68,17 +68,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "tensorflow_js_dialect_registration", - srcs = [ - "ir/dialect_registration.cc", - ], - deps = [ - ":tensorflow_js", - ], - alwayslink = 1, -) - gentbl( name = "tfjs_optimize_inc_gen", tbl_outs = [ @@ -107,7 +96,6 @@ cc_library( ], deps = [ ":tensorflow_js", - ":tensorflow_js_dialect_registration", "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -129,7 +117,6 @@ cc_library( ":tfjs_optimize", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", @@ -149,12 +136,10 @@ cc_library( ], deps = [ ":tensorflow_js", - ":tensorflow_js_dialect_registration", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:export_utils", 
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:framework", "//tensorflow/core:graph", @@ -192,7 +177,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", @@ -236,3 +221,20 @@ tf_cc_binary( "@llvm-project//mlir:Support", ], ) + +tf_cc_binary( + name = "tfjs-opt", + srcs = [ + "tfjs_opt.cc", + ], + deps = [ + ":tensorflow_js", + ":tensorflow_js_passes", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:StandardOps", + ], +) diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc index 9ba875cdce4..5ea3f51b475 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc @@ -15,18 +15,17 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h" -namespace mlir { -namespace tfjs { - #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc.inc" +namespace mlir { +namespace tfjs { + //===----------------------------------------------------------------------===// // TFJSDialect //===----------------------------------------------------------------------===// -TFJSDialect::TFJSDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void TFJSDialect::initialize() { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc.inc" diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h index 9c98c9b0e19..bc52e3a0c7a 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h @@ -29,15 +29,9 @@ limitations under the License. #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -namespace mlir { -namespace tfjs { - #include "tensorflow/compiler/mlir/tfjs/ir/tfjs_dialect.h.inc" #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h.inc" -} // namespace tfjs -} // namespace mlir - #endif // TENSORFLOW_COMPILER_MLIR_TFJS_IR_TFJS_OPS_H_ diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td index 134aa010d8c..e2539c2f6d8 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td @@ -39,7 +39,7 @@ def TFJSDialect : Dialect { TF graphs to be deployed on TFJS. 
}]; - let cppNamespace = "tfjs"; + let cppNamespace = "::mlir::tfjs"; } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tfjs/tests/BUILD b/tensorflow/compiler/mlir/tfjs/tests/BUILD index a4ebc997991..5789480c3ba 100644 --- a/tensorflow/compiler/mlir/tfjs/tests/BUILD +++ b/tensorflow/compiler/mlir/tfjs/tests/BUILD @@ -3,8 +3,11 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package(licenses = ["notice"]) glob_lit_tests( - data = [":test_utilities"], - driver = "@llvm-project//mlir:run_lit.sh", + data = [ + ":test_utilities", + "@llvm-project//mlir:run_lit.sh", + ], + driver = "//tensorflow/compiler/mlir:run_lit.sh", test_file_exts = ["mlir"], ) @@ -13,7 +16,7 @@ filegroup( name = "test_utilities", testonly = True, data = [ - "//tensorflow/compiler/mlir:tf-opt", + "//tensorflow/compiler/mlir/tfjs:tfjs-opt", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", ], diff --git a/tensorflow/compiler/mlir/tfjs/tests/ops.mlir b/tensorflow/compiler/mlir/tfjs/tests/ops.mlir index 0b7210118df..602f34657a0 100644 --- a/tensorflow/compiler/mlir/tfjs/tests/ops.mlir +++ b/tensorflow/compiler/mlir/tfjs/tests/ops.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s +// RUN: tfjs-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s // ----- diff --git a/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir b/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir index 5f046dc5a8a..f4464ddd01d 100644 --- a/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir @@ -1,5 +1,5 @@ // Run optimize pass only and check the results. -// RUN: tf-opt %s -tfjs-optimize | FileCheck %s +// RUN: tfjs-opt %s -tfjs-optimize | FileCheck %s // CHECK-LABEL: prelu_fusion func @prelu_fusion(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { diff --git a/tensorflow/compiler/mlir/tfjs/tfjs_opt.cc b/tensorflow/compiler/mlir/tfjs/tfjs_opt.cc new file mode 100644 index 00000000000..c6013128295 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tfjs_opt.cc @@ -0,0 +1,33 @@ +/* Copyright 2020 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Support/MlirOptMain.h" // from @llvm-project +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h" + +int main(int argc, char **argv) { + tensorflow::InitMlir y(&argc, &argv); + + mlir::registerAllPasses(); + + mlir::DialectRegistry registry; + registry.insert(); + registry.insert(); + registry.insert(); + return failed(mlir::MlirOptMain(argc, argv, "TF JS pass driver\n", registry)); +} diff --git a/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc b/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc index c03a68471bc..a3678f7d154 100644 --- a/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc @@ -37,6 +37,9 @@ namespace { // Optimize TFJS operations in functions. struct Optimize : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const final { + registry.insert(); + } void runOnFunction() override; }; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 066ca221d5d..1fa224f3ac8 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -1,59 +1,127 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load( + "//tensorflow/core/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) +load( + "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm_is_configured", +) -licenses(["notice"]) +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) package_group( name = "friends", includes = ["//third_party/mlir:subpackages"], - packages = ["//tensorflow/compiler/mlir/..."], + packages = [ + "//tensorflow/compiler/mlir/...", + "//tensorflow/core/kernels/mlir_generated/...", + ], ) cc_library( - name = "cubin_creator", - srcs = ["cubin_creator.cc"], - hdrs = ["cubin_creator.h"], - copts = if_cuda(["-DGOOGLE_CUDA=1"]), + name = "kernel_creator", + srcs = ["kernel_creator.cc"], + hdrs = ["kernel_creator.h"], + copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]), deps = [ - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:TargetNVVMIR", - "@llvm-project//mlir:Transforms", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/hlo", + "//tensorflow/compiler/mlir/hlo:all_passes", + "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", + "//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation", + "//tensorflow/compiler/mlir/hlo:legalize_to_linalg", "//tensorflow/compiler/mlir/hlo:lhlo", - "//tensorflow/compiler/mlir/xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg", + "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine", + "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu", "//tensorflow/compiler/mlir/hlo:materialize_broadcasts", # 
buildcleaner: keep + "//tensorflow/compiler/mlir/hlo:transform_unranked_hlo", # buildcleaner: keep "//tensorflow/compiler/mlir/hlo:unfuse_batch_norm", # buildcleaner: keep + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf", "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service/gpu:stream_executor_util", "//tensorflow/compiler/xla/service/gpu:target_constants", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", "//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering", + "//tensorflow/compiler/xla/service/mlir_gpu:passes", "//tensorflow/core:cuda_libdevice_path", "//tensorflow/core:lib", - ] + if_cuda(["//tensorflow/stream_executor/gpu:asm_compiler"]), + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineToStandard", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToGPURuntimeTransforms", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:GPUTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMTransforms", + "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:LinalgToLLVM", + "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFToGPUPass", + "@llvm-project//mlir:SCFToStandard", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TargetNVVMIR", + "@llvm-project//mlir:TargetROCDLIR", + "@llvm-project//mlir:Transforms", + ], ) tf_cc_binary( - name = "tf_to_cubin", - srcs = ["tf_to_cubin.cc"], + name = "tf_to_gpu_binary", + srcs = ["tf_to_gpu_binary.cc"], visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"], deps = [ - ":cubin_creator", + ":kernel_creator", "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:lib", + "//tensorflow/stream_executor/lib", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Pass", + ], +) + +tf_cc_binary( + name = "tf_to_kernel", + srcs = ["tf_to_kernel.cc"], + visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"], + deps = [ + ":kernel_creator", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:X86CodeGen", # fixdeps: keep + "@llvm-project//llvm:X86Disassembler", # fixdeps: keep + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TargetLLVMIR", ], ) @@ -62,15 +130,28 @@ tf_cc_binary( srcs = ["tools/kernel-gen-opt/kernel-gen-opt.cc"], visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen/tests:__pkg__"], deps = [ + "//tensorflow/compiler/mlir/hlo:all_passes", "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", - 
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirOptLib", - "@llvm-project//mlir:MlirOptMain", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", ], ) + +exports_files(["tf_framework_c_interface.h"]) + +cc_library( + name = "tf_framework_c_interface", + srcs = ["tf_framework_c_interface.cc"], + hdrs = ["tf_framework_c_interface.h"], + deps = [ + "//tensorflow/core:framework", + "@llvm-project//mlir:mlir_runner_utils", + ], +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc deleted file mode 100644 index 1f511e27d9e..00000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc +++ /dev/null @@ -1,311 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -//===- cubin_creator.cc -----------------------------------------*- C++ -*-===// -// -// This file implements the function to compile a TF kernel function to a cubin. 
-// -//===----------------------------------------------------------------------===// -#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h" - -#include -#include -#include - -#include "absl/memory/memory.h" -#include "absl/strings/escaping.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Debug.h" -#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Parser.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Target/NVVMIR.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/xla/debug_options_flags.h" -#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" -#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" -#include "tensorflow/compiler/xla/service/gpu/target_constants.h" -#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h" -#include "tensorflow/core/platform/cuda_libdevice_path.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/path.h" -#if GOOGLE_CUDA -#include "tensorflow/stream_executor/gpu/asm_compiler.h" -#endif - -namespace { -using tensorflow::Status; -using xla::InternalError; -using xla::StatusOr; - -StatusOr GetLibdeviceDir( - const xla::HloModuleConfig& hlo_module_config) { - for (const std::string& cuda_root : tensorflow::CandidateCudaRoots( - hlo_module_config.debug_options().xla_gpu_cuda_data_dir())) { - std::string libdevice_dir = - tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); - VLOG(2) << "Looking for libdevice at " << libdevice_dir; - if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { - VLOG(2) << "Found libdevice dir " << libdevice_dir; - return libdevice_dir; - } - } - return InternalError( - "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice"); -} - -struct MaterializeBroadcastsPass - : public mlir::PassWrapper { - void runOnFunction() override { - mlir::ConversionTarget conversionTarget(getContext()); - mlir::OwningRewritePatternList conversionPatterns; - - // Consider the mhlo dialect legal for tests. - conversionTarget.addLegalDialect(); - // The conversion uses helpers from the Standard dialect. 
- conversionTarget.addLegalDialect(); - - mlir::mhlo::SetupMaterializeBroadcastsLegality(&getContext(), - &conversionTarget); - mlir::mhlo::PopulateMaterializeBroadcastsPatterns(&getContext(), - &conversionPatterns); - - if (failed(applyPartialConversion(getFunction(), conversionTarget, - conversionPatterns))) { - return signalPassFailure(); - } - } -}; - -struct UnfuseBatchNormPass - : public mlir::PassWrapper { - void runOnFunction() override { - mlir::OwningRewritePatternList patterns; - mlir::mhlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); - mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); - } -}; - -Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) { - mlir::PassManager pm(module.getContext()); - auto enable_if_vlog_is_on = [](mlir::Pass* pass, mlir::Operation* op) { - return VLOG_IS_ON(1); - }; - pm.enableIRPrinting(/*shouldPrintBeforePass=*/{}, - /*shouldPrintAfterPass=*/enable_if_vlog_is_on, - /*printModuleScope=*/false, - /*printAfterOnlyOnChange=*/false, llvm::dbgs()); - pm.addNestedPass(mlir::mhlo::createLegalizeTFPass(false)); - pm.addNestedPass( - absl::make_unique()); - pm.addNestedPass(absl::make_unique()); - pm.addPass(mlir::mhlo::createLegalizeToLhloPass( - /*results_escape_functions=*/true)); - pm.addNestedPass(mlir::lmhlo::createLhloCopyRemovalPass()); - - if (failed(pm.run(module))) { - return InternalError("Lowering TF to LHLO failed."); - } - return Status::OK(); -} - -struct PropagateTensorFlowABIKnowledge - : public mlir::PassWrapper> { - explicit PropagateTensorFlowABIKnowledge(mlir::FunctionType type, - llvm::ArrayRef same_shape_) - : func_type(type), same_shape(same_shape_) {} - - void runOnOperation() override { - // We know due to tensorflow ABI that the offset is always 0 and that the - // innermost stride is always 1. To make this visible to the compiler, - // we insert constants into the code and replace usages accordingly. - // We do not change the signature so that we keep a somewhat stable ABI - // that is easy to undertand by tools. - // We also know that tensorflow aligns all allocated pointers by 16, so - // we pass this on. Furthermore, we know that arguments never alias. More - // precicely, they may only alias (due to reuse) if the kernel does not - // read from a position it previously has written to. We express this with - // the noalias attribute. - mlir::LLVM::LLVMFuncOp func = getOperation(); - - // This only works if the function is local and we can rewrite it. - if (func.isExternal()) return; - - mlir::OpBuilder b(func.getBody()); - // Steal the LLVM representation of the index type from the third argument. - auto index_type = func.getArgument(3).getType(); - mlir::Value one = b.create( - func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 1)); - mlir::Value zero = b.create( - func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 0)); - uint32_t arg_pos = 0; - std::vector positions; - // Collect the agument and return types of the surrounding function. - auto arg_types = llvm::to_vector<4>(llvm::concat( - func_type.getInputs(), func_type.getResults())); - for (mlir::Type arg_type : arg_types) { - if (!arg_type.isa()) { - func.emitError() << "argument of surrounding func is not ranked memref"; - signalPassFailure(); - return; - } - positions.push_back(arg_pos); - // Set alignment and aliasing on the pointers. - func.setArgAttr(arg_pos + 1, "llvm.noalias", b.getBoolAttr(true)); - func.setArgAttr(arg_pos + 1, "llvm.align", b.getIndexAttr(16)); - // Replace the offset with zero. 
Offset is argument number 3. - func.getArgument(arg_pos + 2).replaceAllUsesWith(zero); - // Forward over base_ptr, aligned_ptr, offset, size and stride arguments. - arg_pos += 3 + arg_type.cast().getRank() * 2; - // Replace the last stride with constant 1. - func.getArgument(arg_pos - 1).replaceAllUsesWith(one); - } - - // If we have knowledge that some arguments have the same shape, we - // can use that here. Simply replace usages of the shape parameters within - // the function body to a single shape parameter. - if (!same_shape.empty()) { - auto first = same_shape.front(); - auto first_offset = positions.at(first); - auto first_type = arg_types[first].cast(); - uint32_t rank = first_type.getRank(); - for (auto same : same_shape.drop_front(1)) { - uint32_t same_offset = positions.at(same); - auto same_type = arg_types[same].cast(); - if (same_type.getRank() != rank) { - func.emitOpError() << "same shape constraints on arguments with " - "non-matching shapes: #" - << first << " and #" << same; - signalPassFailure(); - continue; - } - - for (uint32_t i = 0; i < 2 * rank; ++i) { - // Replace uses for second arg data with first arg. - auto same_arg = func.getArgument(same_offset + 3 + i); - auto first_arg = func.getArgument(first_offset + 3 + i); - same_arg.replaceAllUsesWith(first_arg); - } - } - } - } - - mlir::FunctionType func_type; - llvm::ArrayRef same_shape; -}; - -Status PropagateTensorFlowABIKnowledgeToKernel( - mlir::ModuleOp module, llvm::ArrayRef same_shape) { - // Grab the original signature from the single function. - auto func = *module.getBody()->op_begin(); - - mlir::PassManager pm(module.getContext()); - auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) { - return VLOG_IS_ON(1); - }; - pm.enableIRPrinting(/*shouldPrintBeforePass=*/{}, - /*shouldPrintAfterPass=*/enable_if_vlog_is_on, - /*printModuleScope=*/false, - /*printAfterOnlyOnChange=*/false, llvm::dbgs()); - auto& kernel_pm = pm.nest<::mlir::gpu::GPUModuleOp>(); - kernel_pm.addNestedPass( - absl::make_unique(func.getType(), - same_shape)); - - if (failed(pm.run(module))) { - return InternalError("Static knowledge propagation failed."); - } - return Status::OK(); -} - -void RegisterDialects() { - static bool init_once = []() { - mlir::registerDialect(); - return true; - }(); - (void)init_once; -} -} // namespace - -StatusOr> tensorflow::kernel_gen::GenerateCubinForTfCode( - llvm::StringRef tf_code, std::pair compute_capability, - llvm::ArrayRef tile_sizes, llvm::ArrayRef same_shape, - llvm::ArrayRef unroll_factors) { - RegisterDialects(); - mlir::MLIRContext context; - mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context); - - TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get())); - { - xla::mlir_gpu::LowerLHLOToGPUOptions options; - options.tile_sizes = tile_sizes; - options.unroll_factors = unroll_factors; - options.collapse_parallel_loops = false; - options.use_approximations = true; - TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerLHLOToGPU(module.get(), options)); - } - TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get())); - TF_RETURN_IF_ERROR( - PropagateTensorFlowABIKnowledgeToKernel(module.get(), same_shape)); - - mlir::OwningModuleRef kernel_module = - xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie(); - auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module); - if (!llvmModule) { - return InternalError("Could not translate MLIR module to NVVM"); - } - - llvmModule->setModuleIdentifier("acme"); - 
llvmModule->setDataLayout(xla::gpu::nvptx::kDataLayout); - - xla::HloModuleConfig config; - config.set_debug_options(xla::GetDebugOptionsFromFlags()); - - auto enable_fusion = [](llvm::TargetMachine* target) { - target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast; - }; - - TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config)); - TF_ASSIGN_OR_RETURN( - std::string ptx, - xla::gpu::nvptx::CompileToPtx(llvmModule.get(), compute_capability, - config, libdevice_dir, enable_fusion)); - VLOG(1) << ptx; - -#if GOOGLE_CUDA - return tensorflow::se::CompileGpuAsm( - std::get<0>(compute_capability), std::get<1>(compute_capability), - ptx.c_str(), xla::gpu::PtxOptsFromConfig(config)); -#else - return InternalError( - "GOOGLE_CUDA not defined. Did you specify --config=cuda ?"); -#endif -} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD index 3a28d4815d2..29939f227db 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -35,13 +35,3 @@ cc_library( "@llvm-project//mlir:SideEffects", ], ) - -cc_library( - name = "tf_framework_dialect_registration", - srcs = ["dialect_registration.cc"], - deps = [ - ":tf_framework_ops", - "@llvm-project//mlir:IR", - ], - alwayslink = 1, -) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc index e67b5fd7f85..b3d92773be4 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc @@ -24,8 +24,7 @@ namespace mlir { namespace kernel_gen { namespace tf_framework { -TFFrameworkDialect::TFFrameworkDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void TFFrameworkDialect::initialize() { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc" @@ -49,19 +48,23 @@ Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const { /// Print a type registered to this dialect. void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const { - switch (type.getKind()) { - case TFFrameworkTypes::OpKernelContextType: - os << "op_kernel_context"; - return; - default: - llvm_unreachable("unexpected TF Framework type kind"); + if (type.isa()) { + os << "op_kernel_context"; + return; } + llvm_unreachable("unexpected TF Framework type kind"); +} + +template +LogicalResult Verify(OpTy op) { + return success(); } //===----------------------------------------------------------------------===// // AllocRawOp //===----------------------------------------------------------------------===// -static LogicalResult Verify(AllocRawOp op) { +template <> +LogicalResult Verify(AllocRawOp op) { // Check that the total number of operands matches the number of dynamic // dimensions specified in the memref type. 
unsigned result_dyn_dims = op.getType().getNumDynamicDims(); @@ -74,14 +77,9 @@ static LogicalResult Verify(AllocRawOp op) { return success(); } -//===----------------------------------------------------------------------===// -// DeallocRawOp -//===----------------------------------------------------------------------===// -static LogicalResult Verify(DeallocRawOp op) { return success(); } - -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc" - } // namespace tf_framework } // namespace kernel_gen } // namespace mlir + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc" diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h index 8d6e433d9b9..aab090cc5e0 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h @@ -30,35 +30,20 @@ namespace mlir { namespace kernel_gen { namespace tf_framework { -namespace TFFrameworkTypes { -enum Kind { - OpKernelContextType = Type::FIRST_TF_FRAMEWORK_TYPE, -}; -} // namespace TFFrameworkTypes - /// OpKernelContextType corresponds to C++ class OpKernelContext defined in /// tensorflow/core/framework/op_kernel.h class OpKernelContextType : public Type::TypeBase { public: using Base::Base; - - static OpKernelContextType get(MLIRContext *context) { - return Base::get(context, TFFrameworkTypes::Kind::OpKernelContextType); - } - - /// Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { - return kind == TFFrameworkTypes::Kind::OpKernelContextType; - } }; -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_dialect.h.inc" -#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h.inc" - } // namespace tf_framework } // namespace kernel_gen } // namespace mlir +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_dialect.h.inc" +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h.inc" + #endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_IR_TF_FRAMEWORK_OPS_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td index 65481ad377f..e6e29bcbdc2 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td @@ -29,7 +29,7 @@ def TFFramework_Dialect : Dialect { This dialect contains operations and types for that correspond to TensorFlow C++ Framework. 
}]; - let cppNamespace = "kernel_gen::tf_framework"; + let cppNamespace = "::mlir::kernel_gen::tf_framework"; } def TFFramework_OpKernelContextType : DialectType traits = []> : Op { - let verifier = "return Verify(*this);"; + let verifier = "return Verify<$cppClass>(*this);"; } //===----------------------------------------------------------------------===// @@ -111,4 +111,15 @@ def TFFramework_DeallocRawOp : TFFramework_Op<"dealloc_raw", let assemblyFormat = "`(` $ctx `,` $memref `)` attr-dict `:` type($memref)"; } +//===----------------------------------------------------------------------===// +// NullContextOp +//===----------------------------------------------------------------------===// +def TFFramework_NullContextOp : TFFramework_Op<"null_context", + [NoSideEffect]> { + let summary = "Creates a fake TF context that will be lowered to nullptr"; + let description = [{Needed for testing}]; + let results = (outs TFFramework_OpKernelContextType:$result); + let assemblyFormat = "`(` `)` attr-dict `:` type($result)"; +} + #endif // TF_FRAMEWORK_OPS diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc new file mode 100644 index 00000000000..68d1d581351 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -0,0 +1,243 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +//===- kernel_creator.cc ----------------------------------------*- C++ -*-===// +// +// This file implements the function to compile a TF kernel function to gpu +// binary (hsaco for AMD, cubin for NVIDIA) or to a gpu binary with host side. 
+// +//===----------------------------------------------------------------------===// +#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project +#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project +#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/GPU/ParallelLoopMapper.h" // from @llvm-project +#include "mlir/Dialect/GPU/Passes.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project +#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project +#include "mlir/Dialect/SCF/Passes.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/path.h" + +namespace tensorflow { +namespace kernel_gen { +namespace { + +using tensorflow::Status; +using xla::InternalError; +using xla::StatusOr; + +constexpr llvm::StringRef kGpuBinaryAttrName = "gpu.binary"; + +Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, + llvm::ArrayRef tile_sizes, + llvm::ArrayRef unroll_factors) { + mlir::PassManager pm(module.getContext()); + applyPassManagerCLOptions(pm); + + pm.addPass(mlir::mhlo::createLegalizeTFPass(false)); + if (gpu_binary_only) { + pm.addNestedPass( + mlir::kernel_gen::transforms::CreateMaterializeBroadcastsPass()); + pm.addNestedPass( + mlir::kernel_gen::transforms::CreateUnfuseBatchNormPass()); + pm.addPass(mlir::mhlo::createLegalizeToLhloPass( + /*results_escape_functions=*/true)); + // Moving `AllocOp`s and inserting missing `DeallocOp`s + pm.addPass(::mlir::createBufferPlacementPass()); + pm.addNestedPass(mlir::createCopyRemovalPass()); + pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass()); + } else { + pm.addPass(mlir::createTransformUnrankedHloPass()); + 
pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass()); + pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass()); + } + + // Clean up the IR for further processing. + pm.addPass(mlir::createCanonicalizerPass()); + // We have to anticipate later unrolling in tiling to make sure that we get + // the requested tiling after unrolling. Compute the new tiling here if + // needed. + llvm::SmallVector tiling_for_unrolling; + llvm::SmallVector as_int64; + if (!unroll_factors.empty()) { + tiling_for_unrolling.reserve(tile_sizes.size()); + for (auto pair : llvm::zip(tile_sizes, unroll_factors)) { + tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair)); + as_int64.push_back(std::get<1>(pair)); + } + } else { + tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end()); + } + // Transform LHLO operations to LinAlg. + pm.addPass(::mlir::lmhlo::createLegalizeLhloToLinalgPass()); + // Fuse linalg operations. + pm.addPass(::mlir::lmhlo::createLhloFuseLinalgPass( + /*use_parallel_loops=*/true, tiling_for_unrolling)); + // Transform the Linalg operations inside of the loop nest into parallel + // loops. + pm.addPass(::mlir::createConvertLinalgToParallelLoopsPass()); + // Canonicalize the code to simplify index computations. This is needed so + // that loop bounds have the same value. + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); + // Fuse the inner-most loops. + pm.addPass(xla::mlir_gpu::createFuseInnerParallelLoopsPass()); + // Run CSE to ensure that loads and stores to the same subview get + // recognized as such. + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); + // Forward stores to buffers to loads. + pm.addPass(xla::mlir_gpu::createStoreForwardingPass()); + // Remove now unused temporary buffers. + pm.addPass(xla::mlir_gpu::createDeadTempBufferRemovalPass()); + if (!unroll_factors.empty()) { + pm.addPass(::mlir::createParallelLoopTilingPass(as_int64)); + } + // Some basic cleanup. + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); + // Greedily map the remaining loop to GPU hardware dimensions. + pm.addPass(xla::mlir_gpu::createMapParallelLoopsPass()); + // Apply the mapping. + pm.addPass(mlir::createParallelLoopToGpuPass()); + + // Embed TF Framework ops. + if (!gpu_binary_only) { + pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass()); + } + + // Some basic cleanup. + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); + // Make loops with min bounds into a conditional plus static bounds. + // Only do this if we unrolled in the first place. + if (!unroll_factors.empty()) { + pm.addNestedPass<::mlir::FuncOp>(mlir::createForLoopSpecializationPass()); + } + // Approximate Tanh using standard operations. + pm.addNestedPass<::mlir::FuncOp>( + ::mlir::mhlo::createLegalizeTanhToApproximationPass()); + // Move scalar operations into the launch to ensure smaller signatures. + pm.addPass(xla::mlir_gpu::createMoveScalarComputationsIntoGpuLaunchPass()); + // Take launches to launches with kernels. + pm.addPass(::mlir::createGpuKernelOutliningPass()); + + if (gpu_binary_only) { + // Make kernel signature deterministic so that we can call it externally. 
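To make the tile-size adjustment in the pipeline above concrete: the pass scales each requested tile size by its unroll factor so that the requested tiling is what remains after the later unrolling step. Below is a standalone sketch of just that arithmetic, with illustrative values that are not taken from the patch (the real code uses llvm::zip over the ArrayRef inputs and SmallVector outputs).

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Illustrative inputs; the defaults in the patch are tile_sizes = {16, 64}
  // and an empty unroll_factors list, which skips this scaling entirely.
  std::vector<uint32_t> tile_sizes = {16, 64};
  std::vector<uint32_t> unroll_factors = {2, 4};

  // Scale each tile size by its unroll factor (llvm::zip stops at the shorter
  // of the two lists, mirrored here with std::min).
  std::vector<uint32_t> tiling_for_unrolling;
  size_t n = std::min(tile_sizes.size(), unroll_factors.size());
  for (size_t i = 0; i < n; ++i) {
    tiling_for_unrolling.push_back(tile_sizes[i] * unroll_factors[i]);
  }

  for (uint32_t t : tiling_for_unrolling) std::cout << t << ' ';  // prints: 32 256
  std::cout << '\n';
  return 0;
}

The pipeline resumes below with the kernel-signature rewrite that the preceding comment describes.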
+ pm.addPass(xla::mlir_gpu::createRewriteKernelSignaturePass()); + } + pm.addPass(::mlir::createLowerAffinePass()); + pm.addPass(::mlir::createLowerToCFGPass()); + if (failed(pm.run(module))) { + return InternalError("Lowering to GPU kernels failed."); + } + return Status::OK(); +} + +Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only, + llvm::ArrayRef same_shape, + llvm::StringRef gpu_binary_attr_name, + int32_t architecture) { + mlir::PassManager pm(module.getContext()); + applyPassManagerCLOptions(pm); + + auto& kernel_pm = pm.nest(); + if (gpu_binary_only) { + // Grab the original signature from the single function. + auto func = *module.getBody()->op_begin(); + kernel_pm.addNestedPass( + mlir::kernel_gen::transforms::CreatePropagateTensorFlowABIKnowledgePass( + func.getType(), same_shape)); + } + kernel_pm.addPass(mlir::createStripDebugInfoPass()); + kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass( + gpu_binary_attr_name, architecture)); + + if (!gpu_binary_only) { + pm.addPass(mlir::kernel_gen::transforms::CreateTFKernelToLLVMPass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + } + return failed(pm.run(module)) ? InternalError("Lowering to LLVM IR failed.") + : Status::OK(); +} + +} // namespace + +StatusOr GenerateKernelForTfCode( + mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only, + int32_t architecture, llvm::ArrayRef tile_sizes, + llvm::ArrayRef same_shape, + llvm::ArrayRef unroll_factors) { + mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry()); + mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context); + TF_RETURN_IF_ERROR( + LowerTFtoGPU(module.get(), gpu_binary_only, tile_sizes, unroll_factors)); +#if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA) + return InternalError( + "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined." + " Did you specify either --config=rocm or --config=cuda ?"); +#endif + +#if TENSORFLOW_USE_ROCM + TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToROCDL(module.get())); +#elif GOOGLE_CUDA + TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get())); +#endif + TF_RETURN_IF_ERROR(LowerGPUToLLVM(module.get(), gpu_binary_only, same_shape, + kGpuBinaryAttrName, architecture)); + return module; +} + +StatusOr ExtractGpuBinary(mlir::ModuleOp module) { + auto gpu_modules = module.getOps(); + if (std::distance(gpu_modules.begin(), gpu_modules.end()) != 1) { + return InternalError("There should be exactly one GPU Module"); + } + mlir::gpu::GPUModuleOp gpu_mod = *gpu_modules.begin(); + auto blob = gpu_mod.getAttrOfType(kGpuBinaryAttrName); + if (blob == nullptr) { + return InternalError("No binary blob found in the module"); + } + return blob.getValue().str(); +} + +} // namespace kernel_gen +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h similarity index 53% rename from tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h rename to tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h index 47626ba9d0d..b168ec815de 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h @@ -13,30 +13,39 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -//===- cubin_creator.h ------------------------------------------*- C++ -*-===// +//===- kernel_creator.h -----------------------------------------*- C++ -*-===// // -// This file declares the function to compile a TF kernel function to a cubin. +// This file declares the function to compile a TF kernel function to gpu +// binary (hsaco for AMD, cubin for NVIDIA) or to a gpu binary with host side. // //===----------------------------------------------------------------------===// -#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ -#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_ #include -#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/xla/statusor.h" namespace tensorflow { namespace kernel_gen { -xla::StatusOr> GenerateCubinForTfCode( - llvm::StringRef tf_code, - std::pair compute_capability = {7, 5}, - llvm::ArrayRef tile_sizes = {16, 64}, + +// Converts TF code to LLVM/NVVM. If `gpu_binary_only` is true, then the +// conversion stops after gpu_binary blob is generated. If `gpu_binary_only` is +// false, lowers the host side to LLVM Dialect. +xla::StatusOr GenerateKernelForTfCode( + mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only, + int32_t architecture = 75, llvm::ArrayRef tile_sizes = {16, 64}, llvm::ArrayRef same_shape = {}, llvm::ArrayRef unroll_factors = {}); + +// Extracts gpu_binary from the converted module. +xla::StatusOr ExtractGpuBinary(mlir::ModuleOp module); + } // namespace kernel_gen } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc new file mode 100644 index 00000000000..e75db59d885 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h" + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { +namespace { + +using tensorflow::Allocator; + +Allocator* GetAllocator(void* op_kernel_ctx) { + auto* ctx = static_cast(op_kernel_ctx); + // TODO(pifon): Figure out how to set AllocatorAttributes correctly. 
+ tensorflow::AllocatorAttributes attrs; + return ctx->get_allocator(attrs); +} + +} // namespace + +extern "C" void* _mlir_ciface_tf_alloc_raw(void* op_kernel_ctx, + size_t num_bytes) { + return GetAllocator(op_kernel_ctx) + ->AllocateRaw(Allocator::kAllocatorAlignment, num_bytes); +} + +extern "C" void _mlir_ciface_tf_dealloc_raw(void* op_kernel_ctx, void* ptr) { + GetAllocator(op_kernel_ctx)->DeallocateRaw(ptr); +} + +} // namespace tf_framework +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h new file mode 100644 index 00000000000..143ebc95932 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TESTS_TF_FRAMEWORK_C_INTERFACE_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TESTS_TF_FRAMEWORK_C_INTERFACE_H_ + +#include "mlir/ExecutionEngine/RunnerUtils.h" // from @llvm-project + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { + +extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc_raw( + void* op_kernel_ctx, size_t num_bytes); + +extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc_raw( + void* op_kernel_ctx, void* ptr); + +} // namespace tf_framework +} // namespace kernel_gen +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TESTS_TF_FRAMEWORK_C_INTERFACE_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc similarity index 61% rename from tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc rename to tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc index 96831689600..c7cb92404f5 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -//===- tf_to_cubin.cc -------------------------------------------*- C++ -*-===// +//===- tf_to_gpu_binary.cc --------------------------------------*- C++ -*-===// // -// This file implements the entry point to compile a tf op to a cubin file. 
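For readers unfamiliar with the MLIR C-interface convention, the sketch below shows how host-side code holding a real tensorflow::OpKernelContext could exercise the two entry points declared above. It is illustrative only and not part of the patch; the helper name and buffer size are made up. Generated kernels call these symbols through an opaque void* in the same way.

#include "tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace example {

// Allocates and frees a scratch buffer through the TF framework C interface,
// passing the OpKernelContext as the opaque pointer the generated code expects.
void UseScratchBuffer(tensorflow::OpKernelContext* ctx) {
  void* scratch = mlir::kernel_gen::tf_framework::_mlir_ciface_tf_alloc_raw(
      static_cast<void*>(ctx), /*num_bytes=*/1024);
  // ... fill and consume the buffer here ...
  mlir::kernel_gen::tf_framework::_mlir_ciface_tf_dealloc_raw(
      static_cast<void*>(ctx), scratch);
}

}  // namespace example

The diff continues below with the rename of tf_to_cubin.cc to tf_to_gpu_binary.cc.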
+// This file implements the entry point to compile a tf op to a gpu binary // //===----------------------------------------------------------------------===// #include @@ -23,10 +23,44 @@ #include "absl/strings/string_view.h" #include "llvm/Support/CommandLine.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" -#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +namespace kernel_gen { +namespace { + +xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, + int32_t architecture, llvm::ArrayRef tile_sizes, + llvm::ArrayRef same_shape, + llvm::ArrayRef unroll_factors) { + // Read TF code. + std::string tf_code; + TF_RETURN_IF_ERROR( + ReadFileToString(Env::Default(), input_file.str(), &tf_code)); + // Compile. + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN( + mlir::OwningModuleRef module, + GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/true, + architecture, tile_sizes, same_shape, + unroll_factors)); + // Extract gpu_binary. + TF_ASSIGN_OR_RETURN(std::string gpu_binary, ExtractGpuBinary(*module)); + + // Write gpu_binary blob. + TF_RETURN_IF_ERROR( + WriteStringToFile(Env::Default(), output_file.str(), gpu_binary)); + return xla::Status::OK(); +} + +} // namespace +} // namespace kernel_gen +} // namespace tensorflow int main(int argc, char** argv) { llvm::cl::opt input_file("input", llvm::cl::desc("input file"), @@ -51,38 +85,15 @@ int main(int argc, char** argv) { llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); tensorflow::InitMlir y(&argc, &argv); + mlir::registerPassManagerCLOptions(); llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n"); - std::pair compute_capability(architecture / 10, - architecture % 10); - - std::string tf_code; - auto read_status = tensorflow::ReadFileToString(tensorflow::Env::Default(), - input_file, &tf_code); - if (!read_status.ok()) { - LOG(ERROR) << read_status; - return 1; - } - - auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode( - tf_code, compute_capability, tile_sizes, same_shape, unroll_factors); - - if (!cubin.ok()) { - LOG(ERROR) << cubin.status(); - return 1; - } - - std::vector cubin_data = cubin.ConsumeValueOrDie(); - - auto status = tensorflow::WriteStringToFile( - tensorflow::Env::Default(), output_file, - absl::string_view{reinterpret_cast(cubin_data.data()), - cubin_data.size()}); - + auto status = + tensorflow::kernel_gen::Run(input_file, output_file, architecture, + tile_sizes, same_shape, unroll_factors); if (!status.ok()) { LOG(ERROR) << status; return 1; } - return 0; } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc new file mode 100644 index 00000000000..2caa806551e --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc @@ -0,0 +1,161 @@ +// Copyright 2020 The TensorFlow Runtime Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- tf_to_kernel.cc ------------------------------------------*- C++ -*-===// +// +// This file implements the entry point to compile a tf op to a kernel. +// +//===----------------------------------------------------------------------===// +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Host.h" +#include "llvm/Support/TargetRegistry.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Target/LLVMIR.h" // from @llvm-project +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +namespace kernel_gen { +namespace { + +static llvm::codegen::RegisterCodeGenFlags CGF; + +std::unique_ptr GetTargetMachine(llvm::Module* module) { + llvm::Triple triple(module->getTargetTriple()); + if (triple.getTriple().empty()) { + triple = llvm::Triple(llvm::sys::getDefaultTargetTriple()); + module->setTargetTriple(triple.getTriple()); + } + + std::string error; + const llvm::Target* target = + llvm::TargetRegistry::lookupTarget("", triple, error); + if (!target) { + return nullptr; + } + + llvm::TargetOptions target_options = + llvm::codegen::InitTargetOptionsFromCodeGenFlags(); + return std::unique_ptr(target->createTargetMachine( + triple.str(), "generic", "", target_options, llvm::Reloc::Model::PIC_)); +} + +// Compiles the given MLIR module via LLVM into an executable binary format. +xla::StatusOr EmitToBinary(mlir::ModuleOp module) { + // Translate the module. + llvm::LLVMContext llvm_context; + std::unique_ptr llvm_module = + mlir::translateModuleToLLVMIR(module, llvm_context); + + // Set up the output stream. + llvm::SmallString<8> outstr; + llvm::raw_svector_ostream ostream(outstr); + ostream.SetUnbuffered(); + + llvm::legacy::PassManager codegen_passes; + codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( + llvm::Triple(llvm_module->getTargetTriple()))); + + // TODO(b/163818770): Apply optimizations before dumping .a file. + auto target_machine = GetTargetMachine(llvm_module.get()); + llvm_module->setDataLayout(target_machine->createDataLayout()); + if (target_machine->addPassesToEmitFile(codegen_passes, ostream, nullptr, + llvm::CGFT_ObjectFile, false)) { + return xla::InternalError("Failed add passes to emit file"); + } + codegen_passes.run(*llvm_module); + return ostream.str().str(); +} + +xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, + int32_t architecture, llvm::ArrayRef tile_sizes, + llvm::ArrayRef same_shape, + llvm::ArrayRef unroll_factors) { + // Read TF code. 
+ std::string tf_code; + TF_RETURN_IF_ERROR( + ReadFileToString(Env::Default(), input_file.str(), &tf_code)); + // Compile. + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN( + mlir::OwningModuleRef module, + GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/false, + architecture, tile_sizes, same_shape, + unroll_factors)); + // Get binary. + TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module)); + + // Write .a file. + TF_RETURN_IF_ERROR( + WriteStringToFile(Env::Default(), output_file.str(), binary)); + return xla::Status::OK(); +} + +} // namespace +} // namespace kernel_gen +} // namespace tensorflow + +int main(int argc, char** argv) { + llvm::cl::opt input_file("input", llvm::cl::desc("input file"), + llvm::cl::value_desc("filename"), + llvm::cl::init("foo.mlir")); + llvm::cl::opt output_file( + "output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"), + llvm::cl::init("foo.bin")); + llvm::cl::opt architecture( + "arch", llvm::cl::desc("target architecture (e.g. 50 for sm_50)"), + llvm::cl::init(50)); + llvm::cl::list tile_sizes( + "tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore, + llvm::cl::CommaSeparated); + llvm::cl::list unroll_factors( + "unroll_factors", + llvm::cl::desc("factors to unroll by, separated by commas"), + llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); + llvm::cl::list same_shape( + "same_shape", + llvm::cl::desc("arguments with same shape, separated by commas"), + llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); + + tensorflow::InitMlir y(&argc, &argv); + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::registerPassManagerCLOptions(); + llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n"); + + auto status = + tensorflow::kernel_gen::Run(input_file, output_file, architecture, + tile_sizes, same_shape, unroll_factors); + if (!status.ok()) { + LOG(ERROR) << status; + return 1; + } + return 0; +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc index c1af35617b1..85f1fafd436 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc @@ -13,110 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/InitLLVM.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/ToolOutputFile.h" -#include "mlir/IR/AsmState.h" // from @llvm-project -#include "mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project #include "mlir/InitAllPasses.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/FileUtilities.h" // from @llvm-project #include "mlir/Support/MlirOptMain.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" -// NOLINTNEXTLINE -static llvm::cl::opt inputFilename(llvm::cl::Positional, - llvm::cl::desc(""), - llvm::cl::init("-")); - -// NOLINTNEXTLINE -static llvm::cl::opt outputFilename( - "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), - llvm::cl::init("-")); - -// NOLINTNEXTLINE -static llvm::cl::opt splitInputFile( - "split-input-file", - llvm::cl::desc("Split the input file into pieces and process each " - "chunk independently"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt verifyDiagnostics( - "verify-diagnostics", - llvm::cl::desc("Check that emitted diagnostics match " - "expected-* lines on the corresponding line"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt verifyPasses( - "verify-each", - llvm::cl::desc("Run the verifier after each transformation pass"), - llvm::cl::init(true)); - -// NOLINTNEXTLINE -static llvm::cl::opt allowUnregisteredDialects( - "allow-unregistered-dialect", - llvm::cl::desc("Allow operation with no registered dialects"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt showDialects( - "show-dialects", llvm::cl::desc("Print the list of registered dialects"), - llvm::cl::init(false)); - int main(int argc, char **argv) { - mlir::registerAllDialects(); mlir::registerAllPasses(); - - mlir::mhlo::registerAllDialects(); + mlir::mhlo::registerAllMhloPasses(); + mlir::lmhlo::registerAllLmhloPasses(); mlir::kernel_gen::registerKernelGenPasses(); - llvm::InitLLVM y(argc, argv); + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::mhlo::registerAllMhloDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + registry.insert(); - // Register any pass manager command line options. - mlir::registerAsmPrinterCLOptions(); - mlir::registerPassManagerCLOptions(); - mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run"); - - // Parse pass names in main to ensure static initialization completed. - llvm::cl::ParseCommandLineOptions(argc, argv, - "MLIR modular optimizer driver\n"); - - if (showDialects) { - mlir::MLIRContext context; - llvm::outs() << "Registered Dialects:\n"; - for (mlir::Dialect *dialect : context.getRegisteredDialects()) { - llvm::outs() << dialect->getNamespace() << "\n"; - } - return 0; - } - - // Set up the input file. 
- std::string errorMessage; - auto file = mlir::openInputFile(inputFilename, &errorMessage); - if (!file) { - llvm::errs() << errorMessage << "\n"; - return 1; - } - - auto output = mlir::openOutputFile(outputFilename, &errorMessage); - if (!output) { - llvm::errs() << errorMessage << "\n"; - exit(1); - } - - if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, - splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects))) { - return 1; - } - // Keep the output file if the invocation of MlirOptMain was successful. - output->keep(); - return 0; + return failed( + mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 0d346da9956..b853dea39d2 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -1,4 +1,12 @@ load("//third_party/mlir:tblgen.bzl", "gentbl") +load( + "//tensorflow/core/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) +load( + "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm_is_configured", +) package( default_visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen:friends"], @@ -20,6 +28,21 @@ cc_library( ], ) +cc_library( + name = "bufferize", + srcs = ["bufferize.cc"], + hdrs = ["rewriters.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + cc_library( name = "embed_tf_framework", srcs = ["embed_tf_framework.cc"], @@ -36,7 +59,7 @@ cc_library( ) gentbl( - name = "tf_framework_passes_inc_gen", + name = "kernel_gen_passes_inc_gen", tbl_outs = [("-gen-pass-decls -name KernelGen", "kernel_gen_passes.h.inc")], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", @@ -46,28 +69,57 @@ gentbl( cc_library( name = "passes", srcs = [ + "bufferize_pass.cc", "embed_tf_framework_pass.cc", + "gpu_kernel_to_blob_pass.cc", + "materialize_broadcasts_pass.cc", + "propagate_tf_abi_knowledge_pass.cc", "shape_to_descriptors_pass.cc", - "tf_framework_legalize_to_llvm_pass.cc", + "tf_kernel_to_llvm_pass.cc", + "unfuse_batch_norm_pass.cc", ], hdrs = ["passes.h"], + copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]), deps = [ + "//tensorflow/compiler/mlir/hlo:materialize_broadcasts", # buildcleaner: keep + "//tensorflow/compiler/mlir/hlo:unfuse_batch_norm", # buildcleaner: keep + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla/service/gpu:target_constants", + "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", + "//tensorflow/core:cuda_libdevice_path", + "//tensorflow/core:lib", + ":bufferize", ":embed_tf_framework", + ":kernel_gen_passes_inc_gen", ":tf_framework_legalize_to_llvm", - ":tf_framework_passes_inc_gen", - "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "@llvm-project//llvm:Support", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToGPURuntimeTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMTransforms", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", 
"@llvm-project//mlir:Shape", - "@llvm-project//mlir:ShapeToSCF", + "@llvm-project//mlir:TargetNVVMIR", + "@llvm-project//mlir:TargetROCDLIR", "@llvm-project//mlir:ShapeToStandard", "@llvm-project//mlir:ShapeTransforms", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - ], + "//tensorflow/compiler/mlir/hlo", + "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", + "//tensorflow/compiler/mlir/hlo:lhlo", + "//tensorflow/compiler/xla/service/gpu:stream_executor_util", + "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_llvm", + "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", + ] + if_cuda_is_configured([ + "//tensorflow/stream_executor/gpu:asm_compiler", + ]) + if_rocm_is_configured([ + "//tensorflow/core/platform:rocm_rocdl_path", + ]), ) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc new file mode 100644 index 00000000000..45b8c524650 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc @@ -0,0 +1,188 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for translating mixed IR to buffer form. 
+ +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { +namespace kernel_gen { +namespace transforms { + +namespace { + +class TensorFromElementsOpConverter + : public BufferAssignmentOpConversionPattern { + public: + using BufferAssignmentOpConversionPattern< + TensorFromElementsOp>::BufferAssignmentOpConversionPattern; + + LogicalResult matchAndRewrite( + TensorFromElementsOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + ShapedType result_type = op.getType().cast(); + int number_of_elements = op.elements().size(); + MemRefType memref_type = + MemRefType::get({number_of_elements}, result_type.getElementType()); + Value result = rewriter.create(loc, memref_type); + for (auto operand : llvm::enumerate(operands)) { + Value index = rewriter.create(loc, operand.index()); + rewriter.create(loc, operand.value(), result, index); + } + rewriter.replaceOp(op, {result}); + return success(); + } +}; + +class DynamicTensorFromElementsOpConverter + : public BufferAssignmentOpConversionPattern { + public: + using BufferAssignmentOpConversionPattern< + DynamicTensorFromElementsOp>::BufferAssignmentOpConversionPattern; + + LogicalResult matchAndRewrite( + DynamicTensorFromElementsOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // Allocate memory on stack. + Location loc = op.getLoc(); + DynamicTensorFromElementsOp::Adaptor transformed(operands); + RankedTensorType tensor_ty = op.getType().cast(); + MemRefType memref_type = + MemRefType::get(tensor_ty.getShape(), tensor_ty.getElementType()); + Value result = rewriter.create(loc, memref_type, + transformed.dynamicExtents()); + + // Collect loop bounds. + int64_t rank = tensor_ty.getRank(); + Value zero = rewriter.create(loc, 0); + Value one = rewriter.create(loc, 1); + SmallVector lower_bounds(rank, zero); + SmallVector steps(rank, one); + SmallVector upper_bounds; + int next_dynamic_index = 0; + for (int i = 0; i < rank; i++) { + Value ub = tensor_ty.isDynamicDim(i) + ? transformed.dynamicExtents()[next_dynamic_index++] + : rewriter.create( + loc, memref_type.getDimSize(i)); + upper_bounds.push_back(ub); + } + + // Generate tensor elements. 
+ rewriter.create( + loc, lower_bounds, upper_bounds, steps, + [&](OpBuilder &b, Location loc, ValueRange ivs) { + BlockAndValueMapping mapping; + mapping.map(op.body().getArguments(), ivs); + for (auto &nested_op : op.getBody()->without_terminator()) + b.clone(nested_op, mapping); + auto yield_op = llvm::cast(op.getBody()->getTerminator()); + b.create(loc, mapping.lookup(yield_op.value()), result, ivs); + b.create(loc); + }); + + rewriter.replaceOp(op, {result}); + return success(); + } +}; + +class TensorLoadOpConversion + : public BufferAssignmentOpConversionPattern { + public: + using BufferAssignmentOpConversionPattern< + TensorLoadOp>::BufferAssignmentOpConversionPattern; + + LogicalResult matchAndRewrite( + TensorLoadOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + TensorLoadOpAdaptor adaptor(operands); + rewriter.replaceOp(op, {adaptor.memref()}); + return success(); + } +}; + +class ExtractElementOpConversion + : public BufferAssignmentOpConversionPattern { + public: + using BufferAssignmentOpConversionPattern< + ExtractElementOp>::BufferAssignmentOpConversionPattern; + + LogicalResult matchAndRewrite( + ExtractElementOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + ExtractElementOpAdaptor adaptor(operands); + + if (!adaptor.aggregate().getType().isa()) { + return failure(); + } + + rewriter.replaceOpWithNewOp(op, adaptor.aggregate(), + adaptor.indices()); + return success(); + } +}; + +class TensorCastOpConverter + : public BufferAssignmentOpConversionPattern { + public: + using BufferAssignmentOpConversionPattern< + TensorCastOp>::BufferAssignmentOpConversionPattern; + + LogicalResult matchAndRewrite( + TensorCastOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto tensor_ty = op.getType().dyn_cast(); + if (!tensor_ty) return failure(); + + Value arg = operands.front(); + auto arg_ty = arg.getType().dyn_cast(); + if (!arg_ty) return failure(); + + auto result_ty = converter->convertType(tensor_ty); + rewriter.replaceOpWithNewOp(op, arg, result_ty); + + return success(); + } +}; + +} // namespace + +void populateStandardBufferizePattern(MLIRContext *context, + BufferAssignmentTypeConverter *converter, + OwningRewritePatternList *patterns) { + patterns->insert(context, converter); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc new file mode 100644 index 00000000000..8ddbb15219f --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc @@ -0,0 +1,125 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for translating mixed IR to buffer form. 
+// Currently it supports MHLO and some operations from the Standard dialect. + +#include <memory> + +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" + +namespace mlir { +namespace kernel_gen { +namespace transforms { +namespace { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +// TODO(herhut) : This could become a real pattern in bufferize pass. What we +// would need to do is insert a copy to model the semantics correctly. The same +// is true for the TensorLoad pattern that is already in there. Then buffer +// assignment free insertion and copy removal should clean this up for us. +// +// This pattern erases the `tensor_store(src_unranked_tensor, dst_unranked_memref)` +// op and replaces the result of the op that defines `dst_unranked_memref` with +// the rewritten `src_unranked_tensor`.
+class UnrankedTensorStoreTestOnlyPattern + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::TensorStoreOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + rewriter.replaceOp(op.memref().getDefiningOp(), op.tensor()); + rewriter.replaceOp(op, {}); + return success(); + } +}; + +struct BufferizePass : public BufferizePassBase { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + public: + void runOnOperation() override { + auto& context = getContext(); + ConversionTarget target(context); + target.addLegalDialect(); + target.addLegalOp(); + target.addIllegalDialect(); + target.addIllegalOp(); + target.addDynamicallyLegalOp([&](TensorStoreOp op) { + return !op.tensor().getType().isa(); + }); + + BufferAssignmentTypeConverter converter; + auto typesAreLegal = [&converter](Operation* op) { + return converter.isLegal(op->getOperandTypes()) && + converter.isLegal(op->getResultTypes()); + }; + target.addDynamicallyLegalOp([&](FuncOp op) { + auto inputs = op.getType().getInputs(); + auto results = op.getType().getResults(); + return converter.isLegal(inputs) && converter.isLegal(results) && + converter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp(typesAreLegal); + target.addDynamicallyLegalOp(typesAreLegal); + + OwningRewritePatternList patterns; + mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns); + populateWithBufferAssignmentOpConversionPatterns( + &context, &converter, &patterns); + populateStandardBufferizePattern(&context, &converter, &patterns); + patterns.insert(&context); + + auto module = getOperation(); + if (failed(applyPartialConversion(module, target, patterns))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr > CreateBufferizePass() { + return std::make_unique(); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc index a0cfcae65d1..6aea4d9c619 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc @@ -36,6 +36,10 @@ static constexpr StringRef kTFEntry = "tf_entry"; // * std.dealloc becomes tf_framework.dealloc_raw. class EmbedTFFrameworkPass : public EmbedTFFrameworkPassBase { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: void runOnOperation() override { ModuleOp m = getOperation(); @@ -68,7 +72,7 @@ class EmbedTFFrameworkPass } // namespace -std::unique_ptr > createEmbedTFFrameworkPass() { +std::unique_ptr > CreateEmbedTFFrameworkPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc new file mode 100644 index 00000000000..773e12f2da3 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc @@ -0,0 +1,150 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Target/NVVMIR.h" // from @llvm-project +#include "mlir/Target/ROCDLIR.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/gpu/target_constants.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/platform/cuda_libdevice_path.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/path.h" + +#if GOOGLE_CUDA +#include "tensorflow/stream_executor/gpu/asm_compiler.h" +#elif TENSORFLOW_USE_ROCM +#include "tensorflow/core/platform/rocm_rocdl_path.h" +#endif + +namespace mlir { +namespace kernel_gen { +namespace transforms { +namespace { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +using xla::InternalError; + +class GpuKernelToBlobPass + : public GpuKernelToBlobPassBase { + public: + GpuKernelToBlobPass(mlir::StringRef blob_annotation, int32_t arch) { + blob_annotation_ = blob_annotation; + arch_ = arch; + } + + void runOnOperation() override { + mlir::gpu::GPUModuleOp gpu_module = getOperation(); + auto blob_or = GetGpuBinaryBlob(gpu_module); + if (blob_or.ok()) { + const auto& blob = blob_or.ValueOrDie(); + std::string blob_string(blob.begin(), blob.end()); + gpu_module.setAttr(blob_annotation_, + mlir::StringAttr::get(blob_string, &getContext())); + return; + } + return signalPassFailure(); + } + + xla::StatusOr> GetGpuBinaryBlob( + mlir::gpu::GPUModuleOp gpu_module) { + llvm::LLVMContext llvmContext; +#if TENSORFLOW_USE_ROCM + auto llvmModule = mlir::translateModuleToROCDLIR(gpu_module, llvmContext); + if (!llvmModule) { + return InternalError("Could not translate MLIR module to ROCDL IR"); + } + + llvmModule->setModuleIdentifier("acme"); + + xla::HloModuleConfig config; + config.set_debug_options(xla::GetDebugOptionsFromFlags()); + + std::string libdevice_dir = tensorflow::RocdlRoot(); + + return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch_, config, + libdevice_dir); + +#elif GOOGLE_CUDA + auto llvmModule = mlir::translateModuleToNVVMIR(gpu_module, llvmContext); + if (!llvmModule) { + return InternalError("Could not translate MLIR module to NVVM"); + } + + llvmModule->setModuleIdentifier("acme"); + llvmModule->setDataLayout(xla::gpu::nvptx::kDataLayout); + + xla::HloModuleConfig config; + config.set_debug_options(xla::GetDebugOptionsFromFlags()); + + auto enable_fusion = [](llvm::TargetMachine* target) { + target->Options.AllowFPOpFusion = 
llvm::FPOpFusion::FPOpFusionMode::Fast; + }; + + int32_t cc_major = arch_ / 10; + int32_t cc_minor = arch_ % 10; + TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config)); + TF_ASSIGN_OR_RETURN( + std::string ptx, + xla::gpu::nvptx::CompileToPtx(llvmModule.get(), + std::make_pair(cc_major, cc_minor), + config, libdevice_dir, enable_fusion)); + VLOG(1) << ptx; + + return tensorflow::se::CompileGpuAsm(cc_major, cc_minor, ptx.c_str(), + xla::gpu::PtxOptsFromConfig(config)); +#endif + return InternalError( + "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined." + " Did you specify either --config=rocm or --config=cuda ?"); + } + + private: + xla::StatusOr GetLibdeviceDir( + const xla::HloModuleConfig& hlo_module_config) { + for (const std::string& cuda_root : tensorflow::CandidateCudaRoots( + hlo_module_config.debug_options().xla_gpu_cuda_data_dir())) { + std::string libdevice_dir = + tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); + VLOG(2) << "Looking for libdevice at " << libdevice_dir; + if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { + VLOG(2) << "Found libdevice dir " << libdevice_dir; + return libdevice_dir; + } + } + return InternalError( + "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice"); + } +}; + +} // namespace + +std::unique_ptr> CreateGpuKernelToBlobPass( + mlir::StringRef blob_annotation, int32_t architecture) { + return std::make_unique(blob_annotation, architecture); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/materialize_broadcasts_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/materialize_broadcasts_pass.cc new file mode 100644 index 00000000000..dd3f32e2b3c --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/materialize_broadcasts_pass.cc @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" + +namespace mlir { +namespace kernel_gen { +namespace transforms { +namespace { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +struct MaterializeBroadcastsPass + : public MaterializeBroadcastsPassBase { + void runOnFunction() override { + mlir::ConversionTarget conversionTarget(getContext()); + mlir::OwningRewritePatternList conversionPatterns; + + // Consider the mhlo dialect legal for tests. + conversionTarget.addLegalDialect(); + // The conversion uses helpers from the Standard dialect. 
+ conversionTarget.addLegalDialect(); + + mlir::mhlo::SetupMaterializeBroadcastsLegality(&getContext(), + &conversionTarget); + mlir::mhlo::PopulateMaterializeBroadcastsPatterns(&getContext(), + &conversionPatterns); + + if (failed(applyPartialConversion(getFunction(), conversionTarget, + conversionPatterns))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr CreateMaterializeBroadcastsPass() { + return std::make_unique(); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h index 13f367c9fe4..179059e54eb 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -25,23 +27,41 @@ namespace mlir { namespace kernel_gen { namespace tf_framework { -// Test pass for applying TF Framework -> LLVM patterns. -std::unique_ptr > -createTestTFFrameworkLegalizeToLLVMPass(); - // Pass to replace some of the Standard ops with TF Framework ops. // * adds tf_framework::OpKernelContextType argument to the function // * std.alloc becomes tf_framework.alloc_raw // * std.dealloc becomes tf_framework.dealloc_raw -std::unique_ptr > createEmbedTFFrameworkPass(); +std::unique_ptr > CreateEmbedTFFrameworkPass(); } // namespace tf_framework namespace transforms { +// Pass for applying LLVM legalization patterns. +std::unique_ptr > CreateTFKernelToLLVMPass(); + // Pass to tranform shape computations in shape dialect to standard and scf // using memref descriptors. -std::unique_ptr CreateShapeToDescriptorsPass(); +std::unique_ptr > CreateShapeToDescriptorsPass(); + +// Pass to tranform computations on values to their corresponding parts on +// buffers. +std::unique_ptr > CreateBufferizePass(); + +// Pass to materialize broadcasts. +std::unique_ptr CreateMaterializeBroadcastsPass(); + +// Pass to propagate TF ABI knowledge, e.g. offsets, alignment. +std::unique_ptr> +CreatePropagateTensorFlowABIKnowledgePass( + mlir::FunctionType type = {}, llvm::ArrayRef same_shape = {}); + +// Pass to annotate GPU Module with its PTX. +std::unique_ptr> CreateGpuKernelToBlobPass( + mlir::StringRef blob_annotation = "", int32_t architecture = 0); + +// Pass to unfuse batch norm. +std::unique_ptr CreateUnfuseBatchNormPass(); } // namespace transforms diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td index 61720674926..5264ef3ec94 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td @@ -13,25 +13,61 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TF_FRAMEWORK_PASSES -#define TF_FRAMEWORK_PASSES +#ifndef TF_KERNEL_GEN_PASSES +#define TF_KERNEL_GEN_PASSES include "mlir/Pass/PassBase.td" -def TestTFFrameworkLegalizeToLLVMPass - : Pass<"test-tf-framework-legalize-to-llvm", "ModuleOp"> { - let summary = "Test pass for applying TF Framework -> LLVM patterns."; - let constructor = "tf_framework::createTestTFFrameworkLegalizeToLLVMPass()"; +def TFKernelToLLVMPass : Pass<"tf-kernel-to-llvm", "ModuleOp"> { + let summary = "Pass for applying LLVM legalization patterns."; + let constructor = "transforms::CreateTFKernelToLLVMPass()"; } def EmbedTFFrameworkPass : Pass<"embed-tf-framework", "ModuleOp"> { let summary = "Pass to embed TF Framework for allocation and error reporting"; - let constructor = "tf_framework::createEmbedTFFrameworkPass()"; + let constructor = "tf_framework::CreateEmbedTFFrameworkPass()"; } -def ShapeToDescriptorsPass : Pass<"test-shape-to-descriptors", "ModuleOp"> { +def ShapeToDescriptorsPass : Pass<"shape-to-descriptors", "ModuleOp"> { let summary = "Pass to transform shape computations to descriptors"; let constructor = "transforms::CreateShapeToDescriptorsPass()"; } -#endif // TF_FRAMEWORK_PASSES +def BufferizePass : Pass<"bufferize", "ModuleOp"> { + let summary = "Pass to transform operations on values to buffer based ones"; + let constructor = "transforms::CreateBufferizePass()"; +} + +def MaterializeBroadcastsPass : FunctionPass<"materialize-broadcast"> { + let summary = "Pass to materialize broadcasts"; + let constructor = "transforms::CreateMaterializeBroadcastsPass()"; +} + +def UnfuseBatchNormPass : FunctionPass<"unfuse-batch-norm"> { + let summary = "Pass to unfuse batch norm"; + let constructor = "transforms::CreateUnfuseBatchNormPass()"; +} + +def GpuKernelToBlobPass : Pass<"gpu-kernel-to-blob", "gpu::GPUModuleOp"> { + let summary = "Pass to annotate GPU Module with its PTX"; + let options = [ + Option<"blob_annotation_", "blob-annotation", "mlir::StringRef", + /*default=*/"", "Blob attribute name">, + Option<"arch_", "arch", "int32_t", /*default=*/"0", "GPU architecture">, + ]; + let constructor = "transforms::CreateGpuKernelToBlobPass()"; +} + +def PropagateTensorFlowABIKnowledgePass + : Pass<"propagate-tf-abi-knowledge", "LLVM::LLVMFuncOp"> { + let summary = "Pass to propagate TF ABI knowledge, e.g. offsets, alignment"; + let options = [ + Option<"func_type_", "func-type", "mlir::FunctionType", + /*default=*/"", "Function type">, + ListOption<"same_shape_", "same-shape", "uint32_t", + "List of same shape args">, + ]; + let constructor = "transforms::CreatePropagateTensorFlowABIKnowledgePass()"; +} + +#endif // TF_KERNEL_GEN_PASSES diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/propagate_tf_abi_knowledge_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/propagate_tf_abi_knowledge_pass.cc new file mode 100644 index 00000000000..57a5fec527a --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/propagate_tf_abi_knowledge_pass.cc @@ -0,0 +1,123 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
+
+namespace mlir {
+namespace kernel_gen {
+namespace transforms {
+namespace {
+
+#define GEN_PASS_CLASSES
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
+
+struct PropagateTensorFlowABIKnowledgePass
+    : public PropagateTensorFlowABIKnowledgePassBase<
+          PropagateTensorFlowABIKnowledgePass> {
+  explicit PropagateTensorFlowABIKnowledgePass(
+      mlir::FunctionType type, llvm::ArrayRef<uint32_t> same_shape) {
+    func_type_ = type;
+    same_shape_ = same_shape;
+  }
+
+  void runOnOperation() override {
+    // We know due to tensorflow ABI that the offset is always 0 and that the
+    // innermost stride is always 1. To make this visible to the compiler,
+    // we insert constants into the code and replace usages accordingly.
+    // We do not change the signature so that we keep a somewhat stable ABI
+    // that is easy to understand by tools.
+    // We also know that tensorflow aligns all allocated pointers by 16, so
+    // we pass this on. Furthermore, we know that arguments never alias. More
+    // precisely, they may only alias (due to reuse) if the kernel does not
+    // read from a position it previously has written to. We express this with
+    // the noalias attribute.
+    mlir::LLVM::LLVMFuncOp func = getOperation();
+
+    // This only works if the function is local and we can rewrite it.
+    if (func.isExternal()) return;
+
+    mlir::OpBuilder b(func.getBody());
+    // Steal the LLVM representation of the index type from the third argument.
+    auto index_type = func.getArgument(3).getType();
+    mlir::Value one = b.create<mlir::LLVM::ConstantOp>(
+        func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 1));
+    mlir::Value zero = b.create<mlir::LLVM::ConstantOp>(
+        func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 0));
+    uint32_t arg_pos = 0;
+    std::vector<uint32_t> positions;
+    // Collect the argument and return types of the surrounding function.
+    auto arg_types = llvm::to_vector<4>(llvm::concat<const mlir::Type>(
+        func_type_.getInputs(), func_type_.getResults()));
+    for (mlir::Type arg_type : arg_types) {
+      if (!arg_type.isa<mlir::MemRefType>()) {
+        func.emitError() << "argument of surrounding func is not ranked memref";
+        return signalPassFailure();
+      }
+      positions.push_back(arg_pos);
+      // Set alignment and aliasing on the pointers.
+      func.setArgAttr(arg_pos + 1, "llvm.noalias", b.getBoolAttr(true));
+      func.setArgAttr(arg_pos + 1, "llvm.align", b.getIndexAttr(16));
+      // Replace the offset with zero. Offset is argument number 3.
+      func.getArgument(arg_pos + 2).replaceAllUsesWith(zero);
+      // Forward over base_ptr, aligned_ptr, offset, size and stride arguments.
+      arg_pos += 3 + arg_type.cast<mlir::MemRefType>().getRank() * 2;
+      // Replace the last stride with constant 1.
+      func.getArgument(arg_pos - 1).replaceAllUsesWith(one);
+    }
+
+    // If we have knowledge that some arguments have the same shape, we
+    // can use that here. Simply replace usages of the shape parameters within
+    // the function body to a single shape parameter.
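+    // For example, if arguments #0 and #1 are rank-2 memrefs declared to have
+    // the same shape, the two sizes and two strides of #1's descriptor (the
+    // operands starting three positions into its argument group, after the two
+    // pointers and the offset) are rewired to the corresponding operands of #0.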
+ if (same_shape_.empty()) { + return; + } + auto first = same_shape_.front(); + auto first_offset = positions.at(first); + auto first_type = arg_types[first].cast(); + uint32_t rank = first_type.getRank(); + for (int i = 1, e = same_shape_.size(); i < e; ++i) { + uint32_t same = same_shape_[i]; + uint32_t same_offset = positions.at(same); + auto same_type = arg_types[same].cast(); + if (same_type.getRank() != rank) { + func.emitOpError() << "same shape constraints on arguments with " + "non-matching shapes: #" + << first << " and #" << same; + return signalPassFailure(); + } + + for (uint32_t i = 0; i < 2 * rank; ++i) { + // Replace uses for second arg data with first arg. + auto same_arg = func.getArgument(same_offset + 3 + i); + auto first_arg = func.getArgument(first_offset + 3 + i); + same_arg.replaceAllUsesWith(first_arg); + } + } + } +}; + +} // namespace + +std::unique_ptr> +CreatePropagateTensorFlowABIKnowledgePass(mlir::FunctionType type, + llvm::ArrayRef same_shape) { + return std::make_unique(type, + same_shape); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h index 257e84b4a21..0f2a41b3de6 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h @@ -20,6 +20,8 @@ limitations under the License. namespace mlir { +class BufferAssignmentPlacer; +class BufferAssignmentTypeConverter; class LLVMTypeConverter; class MLIRContext; class OwningRewritePatternList; @@ -37,6 +39,15 @@ void PopulateEmbedTFFrameworkConversionPatterns( MLIRContext *context, OwningRewritePatternList *patterns); } // namespace tf_framework + +namespace transforms { + +/// Collects a set of patterns that bufferize operations from the standard +/// dialect. +void populateStandardBufferizePattern(MLIRContext *context, + BufferAssignmentTypeConverter *converter, + OwningRewritePatternList *patterns); +} // namespace transforms } // namespace kernel_gen } // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc index 9c1b434b9b2..ab66c513e33 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc @@ -16,7 +16,6 @@ limitations under the License. // This file combines patterns for lowering shape dialect to standard ops, // structured control flow and descriptors. -#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h" // from @llvm-project #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project @@ -24,8 +23,8 @@ limitations under the License. 
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" namespace mlir { namespace kernel_gen { @@ -37,6 +36,10 @@ namespace { struct ShapeToDescriptorsPass : public ShapeToDescriptorsPassBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + public: void runOnOperation() override { MLIRContext &ctx = getContext(); @@ -51,7 +54,6 @@ struct ShapeToDescriptorsPass OwningRewritePatternList patterns; populateShapeRewritePatterns(&ctx, patterns); populateShapeToStandardConversionPatterns(patterns, &ctx); - populateShapeToSCFConversionPatterns(patterns, &ctx); // Apply conversion. auto module = getOperation(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc index 2edcaabd7b4..3ce111ff3ff 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -101,6 +101,7 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern { protected: StringRef GetFuncName() const override { return kCInterfaceAlloc; } + LLVMType GetFuncType() const override { LLVMType llvm_void_ptr_type = getVoidPtrType(); return LLVM::LLVMType::getFunctionTy( @@ -175,10 +176,23 @@ class DeallocRawOpConverter : public ConvertToLLVMCallOpPattern { } }; +class NullContextOpConverter : public ConvertOpToLLVMPattern { + public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, getVoidPtrType()); + return success(); + } +}; + } // namespace void PopulateTFFrameworkToLLVMConversionPatterns( LLVMTypeConverter *converter, OwningRewritePatternList *patterns) { + patterns->insert(*converter); patterns->insert(*converter); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc similarity index 63% rename from tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc rename to tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc index 916eedb55de..b2fcc424a50 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc @@ -13,25 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" namespace mlir { namespace kernel_gen { -namespace tf_framework { +namespace transforms { namespace { #define GEN_PASS_CLASSES #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" -class TestTFFrameworkToLLVMPass - : public TestTFFrameworkLegalizeToLLVMPassBase { +class TFKernelToLLVMPass : public TFKernelToLLVMPassBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + public: void runOnOperation() override { ModuleOp m = getOperation(); @@ -39,21 +44,25 @@ class TestTFFrameworkToLLVMPass // Populate type conversions. LLVMTypeConverter type_converter(m.getContext()); type_converter.addConversion([&](tf_framework::OpKernelContextType type) { - return LLVM::LLVMType::getInt8PtrTy(type_converter.getDialect()); + return LLVM::LLVMType::getInt8PtrTy(m.getContext()); }); // Populate patterns. OwningRewritePatternList patterns; populateStdToLLVMConversionPatterns(type_converter, patterns); - PopulateTFFrameworkToLLVMConversionPatterns(&type_converter, &patterns); + tf_framework::PopulateTFFrameworkToLLVMConversionPatterns(&type_converter, + &patterns); + populateGpuToLLVMConversionPatterns(type_converter, patterns, "gpu.binary"); + lmhlo::PopulateLhloToLLVMConversionPatterns(&type_converter, &patterns); // Set target. ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalDialect(); - target.addLegalOp(); + target + .addIllegalDialect(); + target.addIllegalOp(); - if (failed(applyFullConversion(m, target, patterns))) { + if (failed(applyPartialConversion(m, target, patterns))) { signalPassFailure(); } } @@ -61,11 +70,10 @@ class TestTFFrameworkToLLVMPass } // namespace -std::unique_ptr > -createTestTFFrameworkLegalizeToLLVMPass() { - return std::make_unique(); +std::unique_ptr > CreateTFKernelToLLVMPass() { + return std::make_unique(); } -} // namespace tf_framework +} // namespace transforms } // namespace kernel_gen } // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc new file mode 100644 index 00000000000..d2773d91b07 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" + +namespace mlir { +namespace kernel_gen { +namespace transforms { +namespace { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +struct UnfuseBatchNormPass + : public UnfuseBatchNormPassBase { + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + mlir::mhlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); + mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); + } +}; + +} // namespace + +std::unique_ptr CreateUnfuseBatchNormPass() { + return std::make_unique(); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/utils/array_container_utils.h b/tensorflow/compiler/mlir/utils/array_container_utils.h new file mode 100644 index 00000000000..c1a898185d9 --- /dev/null +++ b/tensorflow/compiler/mlir/utils/array_container_utils.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_ + +#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { + +template +inline llvm::ArrayRef SpanToArrayRef(absl::Span span) { + return llvm::ArrayRef(span.data(), span.size()); +} + +template +inline llvm::ArrayRef SpanToArrayRef(absl::Span span) { + return llvm::ArrayRef(span.data(), span.size()); +} + +template +inline llvm::MutableArrayRef SpanToMutableArrayRef(absl::Span span) { + return llvm::MutableArrayRef(span.data(), span.size()); +} + +template +inline absl::Span ArrayRefToSpan(llvm::ArrayRef ref) { + return absl::Span(ref.data(), ref.size()); +} + +template +inline absl::Span MutableArrayRefToSpan(llvm::MutableArrayRef ref) { + return absl::Span(ref.data(), ref.size()); +} + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_ diff --git a/tensorflow/compiler/mlir/utils/name_utils.cc b/tensorflow/compiler/mlir/utils/name_utils.cc new file mode 100644 index 00000000000..bc4e80f5aa1 --- /dev/null +++ b/tensorflow/compiler/mlir/utils/name_utils.cc @@ -0,0 +1,99 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/utils/name_utils.h" + +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "mlir/IR/Identifier.h" // from @llvm-project + +namespace mlir { + +namespace { +// Checks if a character is legal for a TensorFlow node name, with special +// handling if a character is at the beginning. +bool IsLegalChar(char c, bool first_char) { + if (isalpha(c)) return true; + if (isdigit(c)) return true; + if (c == '.') return true; + if (c == '_') return true; + + // First character of a node name can only be a letter, digit, dot or + // underscore. + if (first_char) return false; + + if (c == '/') return true; + if (c == '-') return true; + + return false; +} +} // anonymous namespace + +void LegalizeNodeName(std::string& name) { + if (name.empty()) return; + + if (!IsLegalChar(name[0], /*first_char=*/true)) name[0] = '.'; + + for (char& c : llvm::drop_begin(name, 1)) + if (!IsLegalChar(c, /*first_char=*/false)) c = '.'; +} + +std::string GetNameFromLoc(Location loc) { + llvm::SmallVector loc_names; + llvm::SmallVector locs; + locs.push_back(loc); + bool names_is_nonempty = false; + + while (!locs.empty()) { + Location curr_loc = locs.pop_back_val(); + + if (auto name_loc = curr_loc.dyn_cast()) { + // Add name in NameLoc. For NameLoc we also account for names due to ops + // in functions where the op's name is first. + auto name = name_loc.getName().strref().split('@').first; + loc_names.push_back(name); + if (!name.empty()) names_is_nonempty = true; + continue; + } else if (auto call_loc = curr_loc.dyn_cast()) { + // Add name if CallSiteLoc's callee has a NameLoc (as should be the + // case if imported with DebugInfo). + if (auto name_loc = call_loc.getCallee().dyn_cast()) { + auto name = name_loc.getName().strref().split('@').first; + loc_names.push_back(name); + if (!name.empty()) names_is_nonempty = true; + continue; + } + } else if (auto fused_loc = curr_loc.dyn_cast()) { + // Push all locations in FusedLoc in reverse order, so locations are + // visited based on order in FusedLoc. + auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations()); + locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end()); + continue; + } + + // Location is not a supported, so an empty StringRef is added. + loc_names.push_back(llvm::StringRef()); + } + + if (names_is_nonempty) + return llvm::join(loc_names.begin(), loc_names.end(), ";"); + + return ""; +} + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/utils/name_utils.h b/tensorflow/compiler/mlir/utils/name_utils.h new file mode 100644 index 00000000000..4b08a41feec --- /dev/null +++ b/tensorflow/compiler/mlir/utils/name_utils.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_ + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Location.h" // from @llvm-project + +namespace mlir { + +// Converts characters in name that are considered illegal in TensorFlow Node +// name to '.'. +void LegalizeNodeName(std::string& name); + +// Creates a TensorFlow node name from a location. +std::string GetNameFromLoc(Location loc); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_ diff --git a/tensorflow/compiler/mlir/utils/string_container_utils.h b/tensorflow/compiler/mlir/utils/string_container_utils.h new file mode 100644 index 00000000000..fb2fa06ca4d --- /dev/null +++ b/tensorflow/compiler/mlir/utils/string_container_utils.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_ + +#include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { + +inline absl::string_view StringRefToView(llvm::StringRef ref) { + return absl::string_view(ref.data(), ref.size()); +} + +inline llvm::StringRef StringViewToRef(absl::string_view view) { + return llvm::StringRef(view.data(), view.size()); +} + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_ diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index ada81634567..aa37181f9f0 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -15,6 +15,7 @@ package_group( "//learning/brain/experimental/mlir/...", "//learning/brain/google/xla/kernels/...", "//learning/brain/google/xla/mlir/...", + "//learning/deepmind/partir/...", "//learning/pathways/data_parallel/tf2xla/...", "//platforms/xla/...", "//tensorflow/compiler/mlir/...", @@ -55,7 +56,9 @@ cc_library( "transforms/passes.h", ], deps = [ + ":attribute_importer", ":type_to_shape", + ":xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo", "//tensorflow/compiler/mlir/hlo:convert_op_folder", @@ -68,7 +71,7 @@ cc_library( "//tensorflow/compiler/xla/client/lib:conv_grad_size_util", "//tensorflow/core:framework", "//tensorflow/core/kernels:conv_grad_shape_utils", - "//tensorflow/core/lib/bfloat16", + "//tensorflow/core/platform:bfloat16", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:Dialect", @@ -94,6 +97,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:translate_utils", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_context", @@ -130,7 +134,6 @@ cc_library( ":hlo_utils", ":mlir_hlo_to_hlo", "//tensorflow/compiler/mlir/hlo", - "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration", "//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", @@ -235,8 +238,8 @@ cc_library( hdrs = ["mlir_hlo_to_hlo.h"], deps = [ ":type_to_shape", + "//tensorflow/compiler/mlir:name_utils", "//tensorflow/compiler/mlir/hlo", - "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration", "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/tf2xla:common", @@ -321,6 +324,16 @@ cc_library( ], ) +cc_library( + name = "translate_cl_options", + srcs = ["xla_mlir_translate_cl.cc"], + hdrs = ["xla_mlir_translate_cl.h"], + deps = [ + "@llvm-project//llvm:Support", + ], + alwayslink = 1, +) + cc_library( name = "xla_mlir_translate", srcs = ["xla_mlir_translate.cc"], @@ -329,8 +342,10 @@ cc_library( ":hlo_to_mlir_hlo", ":mhlo_to_lhlo_with_xla", ":mlir_hlo_to_hlo", + ":translate_cl_options", "//tensorflow/compiler/jit:xla_cpu_jit", "//tensorflow/compiler/jit:xla_gpu_jit", + "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", @@ -339,6 +354,7 
@@ cc_library( "//tensorflow/core:lib", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Translation", ], alwayslink = 1, @@ -385,14 +401,12 @@ cc_library( ":xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo", - "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration", "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", "//tensorflow/compiler/mlir/hlo:legalize_control_flow", "//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation", "//tensorflow/compiler/mlir/hlo:legalize_to_linalg", "//tensorflow/compiler/mlir/hlo:legalize_to_standard", "//tensorflow/compiler/mlir/hlo:lhlo", - "//tensorflow/compiler/mlir/hlo:lhlo_copy_removal", "//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu", diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index d366a36c212..a63fc12c285 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -521,6 +521,13 @@ StatusOr HloFunctionImporter::ImportInstruction( RandomDistributionToString(instruction->random_distribution()))); } } + case HloOpcode::kRngBitGenerator: { + auto rng_op = Cast(instruction); + auto op = func_builder->create( + loc, result_type, + func_builder->getI32IntegerAttr(rng_op->algorithm()), operands[0]); + return op.getOperation(); + } case HloOpcode::kWhile: { auto op = func_builder->create( loc, operands[0].getType(), operands[0]); diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index db981bb0227..e0cc89004cf 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/types/optional.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project @@ -62,7 +63,10 @@ class HloFunctionImporter { : context_(module.getContext()), module_(module), builder_(builder), - function_map_(function_map) {} + function_map_(function_map) { + context_->loadDialect(); + context_->loadDialect(); + } // Imports the given computation as a new function, if it hasn't been already // imported. diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc index dd045da3899..9db5861934f 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc @@ -30,6 +30,12 @@ limitations under the License. 
namespace xla { +HloModuleImporter::HloModuleImporter(mlir::ModuleOp module) + : module_(module), builder_(module.getContext()) { + module.getContext()->loadDialect(); + module.getContext()->loadDialect(); +} + Status HloModuleImporter::Import(const xla::HloModule& module) { // TODO(hinsu): Only import the entry computation here once all HLO ops with // reference to other computation are updated to have a region instead of a diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.h b/tensorflow/compiler/mlir/xla/hlo_module_importer.h index 69ac1e28219..401299484ed 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.h @@ -38,8 +38,7 @@ class Shape; // dialect. HloModuleImporter does not take ownership. class HloModuleImporter { public: - explicit HloModuleImporter(mlir::ModuleOp module) - : module_(module), builder_(module.getContext()) {} + explicit HloModuleImporter(mlir::ModuleOp module); // Import the HloModule into the MLIR Module. Status Import(const xla::HloModule& module); diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index cf78c81908d..b9d563a659d 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -22,7 +22,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -83,6 +83,9 @@ StatusOr> GetPermutationIfAvailable( strides[dim] = accumulated_stride; accumulated_stride *= shape.dimensions(dim); } + if (accumulated_stride == 0) { + return llvm::SmallVector{}; + } return llvm::SmallVector{ makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())}; } diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index c94110d9102..ac5e01a0abf 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -312,6 +312,16 @@ StatusOr MlirHloBuilder::RngOpInternal( return CreateOp(op_name, shape, operands); } +StatusOr MlirHloBuilder::RngBitGeneratorInternal( + const Shape& full_result_shape, RandomAlgorithm algorithm, + XlaOp initial_state) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + full_result_shape, builder_)); + auto op = builder_.create( + loc_, ty, builder_.getI32IntegerAttr(algorithm), GetValue(initial_state)); + return MakeXlaOp(op); +} + StatusOr MlirHloBuilder::ReshapeInternal(const Shape& shape, XlaOp operand, int64 inferred_dimension) { @@ -351,6 +361,13 @@ StatusOr MlirHloBuilder::InDimBroadcast( return MakeXlaOp(op.getResult()); } +StatusOr MlirHloBuilder::AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + absl::Span operands) { + return Unimplemented("MlirHloBuilder does not support op %s", + HloOpcodeString(opcode)); +} + StatusOr MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, ComparisonDirection direction) { @@ -382,6 +399,31 @@ XlaOp MlirHloBuilder::CreateToken() { }); } +StatusOr MlirHloBuilder::TriangularSolveInternal( + const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_ty, + ConvertShapeToType(shape, builder_)); + auto op = builder_.create( + loc_, 
result_ty, GetValue(a), GetValue(b), + builder_.getBoolAttr(options.left_side()), + builder_.getBoolAttr(options.lower()), + builder_.getBoolAttr(options.unit_diagonal()), + builder_.getStringAttr( + TriangularSolveOptions::Transpose_Name(options.transpose_a()))); + return MakeXlaOp(op); +} + +StatusOr MlirHloBuilder::CholeskyInternal(const Shape& shape, XlaOp a, + bool lower) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_ty, + ConvertShapeToType(shape, builder_)); + auto op = builder_.create( + loc_, result_ty, GetValue(a), builder_.getBoolAttr(lower)); + return MakeXlaOp(op); +} + StatusOr MlirHloBuilder::InfeedWithTokenInternal( const Shape& infeed_instruction_shape, XlaOp token, const string& config) { TF_ASSIGN_OR_RETURN(mlir::Type result_type, diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index a12eb723465..00b7aa4d0b0 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -124,6 +124,13 @@ class MlirHloBuilder : public XlaBuilder { FftType fft_type, absl::Span fft_length) override; + StatusOr TriangularSolveInternal( + const Shape& shape, XlaOp a, XlaOp b, + TriangularSolveOptions options) override; + + StatusOr CholeskyInternal(const Shape& shape, XlaOp a, + bool lower) override; + StatusOr CustomCallInternal( const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque, @@ -176,6 +183,9 @@ class MlirHloBuilder : public XlaBuilder { StatusOr RngOpInternal(RandomDistribution distribution, absl::Span parameters, const Shape& shape) override; + StatusOr RngBitGeneratorInternal(const Shape& full_result_shape, + RandomAlgorithm algorithm, + XlaOp initial_state) override; StatusOr ReshapeInternal(const Shape& shape, XlaOp operand, int64 inferred_dimension) override; @@ -189,6 +199,9 @@ class MlirHloBuilder : public XlaBuilder { const Shape& shape, XlaOp operand, absl::Span broadcast_dimensions) override; + StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, + absl::Span operands) override; + StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, ComparisonDirection direction) override; diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index e6d0b8f8dd8..d6ef39d03dd 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -41,6 +41,7 @@ limitations under the License. #include "mlir/IR/UseDefLists.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/utils/name_utils.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" @@ -105,6 +106,9 @@ static mlir::LogicalResult GetXlaOp( // TODO(hpucha): This should be consolidated into a general place. static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); } +static uint32_t Convertuint32_t(uint32_t i) { return i; } +static uint64_t Convertuint64_t(uint64_t i) { return i; } + // Convert APFloat to double. 
static double ConvertAPFloat(llvm::APFloat value) { const auto& semantics = value.getSemantics(); @@ -430,6 +434,27 @@ static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute( return frontend_attributes; } +// Returns a OpMetadata proto based on the location of the op. If the location +// is unknown, an empty proto is returned. `op_name` are populated with the op +// location (converted). FileLineColLoc locations are populated by taking the +// file name and line number, and populating `source_file` and `source_line` +// respectively. +static xla::OpMetadata CreateOpMetadataFromLocation(mlir::Operation* op) { + xla::OpMetadata metadata; + if (op->getLoc().isa()) return metadata; + + std::string name = mlir::GetNameFromLoc(op->getLoc()); + mlir::LegalizeNodeName(name); + metadata.set_op_name(name); + + if (auto file_line_col_loc = op->getLoc().dyn_cast()) { + metadata.set_source_file(file_line_col_loc.getFilename().str()); + metadata.set_source_line(file_line_col_loc.getLine()); + } + + return metadata; +} + // Checks if all shardings are set. static bool AllOptionalShardingsAreSet( llvm::ArrayRef> shardings) { @@ -761,7 +786,7 @@ LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) { LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()), - op.iota_dimension().getSExtValue()); + op.iota_dimension()); return success(); } @@ -882,6 +907,17 @@ LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) { return failure(); } +LogicalResult ExportXlaOp(RngBitGeneratorOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + auto result = op.getResult(); + auto xla_arg_1 = value_map[*op.getODSOperands(0).begin()]; + auto xla_result = xla::RngBitGenerator( + static_cast(op.rng_algorithm()), Unwrap(xla_arg_1), + xla::TypeToShape(result.getType()).tuple_shapes(1)); + value_map[result] = xla_result; + return mlir::success(); +} + LogicalResult ExportXlaOp(RngNormalOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; xla::XlaOp mu, sigma; @@ -974,7 +1010,7 @@ LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; value_map[op] = xla::Sort(GetTuple(op.operands(), ctx), comparator, - op.dimension().getSExtValue(), op.is_stable()); + op.dimension(), op.is_stable()); return success(); } diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc index 407a7d3da38..801c04496f0 100644 --- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -165,6 +165,11 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { "frontend_attributes(lowering_context.builder, " "CreateOpFrontendAttributesFromAttribute(op));\n\n"; + // Create a scoped object to assign op metadata to generated XLA ops. + os << " xla::XlaScopedOpMetadataAssignment " + "op_metadata(lowering_context.builder, " + "CreateOpMetadataFromLocation(op));\n\n"; + // Retrieve all the definitions derived from HLO_Op and sort by record name. for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) { // Skip operations that have a custom exporter. 
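The two changes above add a CreateOpMetadataFromLocation helper to mlir_hlo_to_hlo.cc and make the auto-generated exporters install a scoped xla::XlaScopedOpMetadataAssignment, so every emitted XLA op carries metadata derived from its MLIR location. A minimal sketch of the expected mapping follows; it is an editor's illustration rather than part of the patch, the xla_data.pb.h include path is assumed, and the name, file, and line values are hypothetical.

#include <string>

#include "tensorflow/compiler/xla/xla_data.pb.h"  // assumed header for xla::OpMetadata

// Mirrors the behavior of CreateOpMetadataFromLocation above:
//  * a NameLoc such as "add_1@some_func" yields op_name = "add_1" (the part
//    before '@', with illegal characters rewritten to '.' by LegalizeNodeName);
//  * a top-level FileLineColLoc yields source_file and source_line instead;
//  * an UnknownLoc yields an empty proto.
xla::OpMetadata ExampleMetadataForNameLoc() {
  xla::OpMetadata metadata;
  metadata.set_op_name("add_1");  // hypothetical node name
  return metadata;
}

xla::OpMetadata ExampleMetadataForFileLineColLoc() {
  xla::OpMetadata metadata;
  metadata.set_source_file("model.py");  // hypothetical source file
  metadata.set_source_line(42);          // hypothetical line number
  return metadata;
}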
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt index 3630d2d45e4..a83e36cff64 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt @@ -8,6 +8,6 @@ HloModule TestModule ENTRY TestComputation { x = f32[3, 2]{1,0} parameter(0) - // CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> () + // CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) {name = "copy.1"} : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> () ROOT x.copy = f32[3, 2]{0,1} copy(x) } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir index 69eaeeb946d..5a07d9303f0 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir @@ -17,9 +17,7 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> // CHECK: [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]] // CHECK: [[LHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32> // CHECK: [[RHSBCASTSHAPE:%.*]] = shape.concat [[BCASTHEAD]], [[RHSTAIL]] -// CHECK: [[RHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[RHSBCASTSHAPE]] -// CHECK: [[RHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32> -// CHECK: [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> +// CHECK: [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHS]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> // CHECK: return [[RESULT]] : tensor<3x4x4xf32> // CHECK: } @@ -29,7 +27,6 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> { // CHECK-LABEL: func @batchmatmulv2_lhs_batch -// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} // CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, @@ -43,7 +40,6 @@ func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> { // CHECK-LABEL: func @batchmatmulv2_rhs_batch // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) 
{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} -// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} // CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, // CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, @@ -64,20 +60,20 @@ func @batchmatmulv2_dynamic(%arg0: tensor, %arg1: tensor) return %0 : tensor } -func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> { +func @batchmatmulv2_adj_real(%arg0: tensor<2x5xf32>, %arg1: tensor<4x2xf32>) -> tensor<5x4xf32> { // CHECK-LABEL: func @batchmatmulv2_adj_real // CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>, // CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>, // CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>, // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xf32>, tensor<4x2xf32>) -> tensor<5x4xf32> return %0 : tensor<5x4xf32> } -func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex>, %arg1: tensor<2x4xcomplex>) -> tensor<5x4xcomplex> { +func @batchmatmulv2_adj_complex(%arg0: tensor<2x5xcomplex>, %arg1: tensor<4x2xcomplex>) -> tensor<5x4xcomplex> { // CHECK-LABEL: func @batchmatmulv2_adj_complex( -// CHECK-SAME: [[LHS:%.*]]: tensor<5x2xcomplex>, [[RHS:%.*]]: tensor<2x4xcomplex>) -> tensor<5x4xcomplex> { +// CHECK-SAME: [[LHS:%.*]]: tensor<2x5xcomplex>, [[RHS:%.*]]: tensor<4x2xcomplex>) -> tensor<5x4xcomplex> { // CHECK: [[LHSRE:%.*]] = "mhlo.real"([[LHS]]) // CHECK: [[LHSIM:%.*]] = "mhlo.imag"([[LHS]]) // CHECK: [[LHSIMNEG:%.*]] = "mhlo.negate"([[LHSIM]]) @@ -88,6 +84,6 @@ func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex>, %arg1: tensor<2 // CHECK: [[RHSCONJ:%.*]] = "mhlo.complex"([[RHSRE]], [[RHSIMNEG]]) // CHECK: shape.shape_of [[LHSCONJ]] // CHECK: shape.shape_of [[RHSCONJ]] - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex>, tensor<2x4xcomplex>) -> tensor<5x4xcomplex> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xcomplex>, tensor<4x2xcomplex>) -> tensor<5x4xcomplex> return %0 : tensor<5x4xcomplex> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir index fd9c14c7c0f..887fdea5a21 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir @@ -2,6 +2,7 @@ // (unlike the rest), since this is the primary use case for such ops and // verification of shapes and broadcasts is desired. // RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -canonicalize %s | FileCheck %s +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FileCheck --check-prefix CHLO %s //===----------------------------------------------------------------------===// // Binary op legalizations. 
@@ -48,8 +49,8 @@ func @add_dynamic(%arg0: tensor, %arg1: tensor) -> tensor, tensor -> tensor + // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_SHAPE]] : tensor to tensor<2xindex> // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-NEXT: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor @@ -58,6 +59,15 @@ func @add_dynamic(%arg0: tensor, %arg1: tensor) -> tensor } +// CHECK-LABEL: func @broadcast_add_unranked +// CHLO-LABEL: func @broadcast_add_unranked +func @broadcast_add_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { + // CHECK: tf.Add + // CHLO: chlo.broadcast_add %arg0, %arg1 + %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32> + return %0: tensor<*xi32> +} + // CHECK-LABEL: func @div func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-NEXT: %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32> @@ -139,9 +149,9 @@ func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8 } // CHECK-LABEL: func @and -func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { +func @and(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { // CHECK-NEXT: mhlo.and - %0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> + %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> return %0: tensor<2xi1> } @@ -153,9 +163,9 @@ func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> { } // CHECK-LABEL: func @or -func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { +func @or(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { // CHECK-NEXT: mhlo.or - %0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> + %0 = "tf.LogicalOr"(%arg0, %arg1) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> return %0: tensor<2xi1> } @@ -201,8 +211,8 @@ func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor // NOT-CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[LHS_SHAPE]], %[[RHS_SHAPE]] // NOT-CHECK-NEXT: shape.assuming %[[WITNESS]] -> (tensor) { // NOT-CHECK-DAG: %[[LHS_SHAPE1:.+]] = shape.shape_of %arg0 - // NOT-CHECK-NEXT: %[[RESULT_SHAPE:.+]] = shape.broadcast %[[LHS_SHAPE1]], %[[RHS_SHAPE]] - // NOT-CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] + // NOT-CHECK-NEXT: %[[RESULT_SHAPE:.+]] = shape.broadcast %[[LHS_SHAPE1]], %[[RHS_SHAPE]] : tensor, tensor -> tensor + // NOT-CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_SHAPE]] : tensor to tensor<1xindex> // NOT-CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // NOT-CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // NOT-CHECK-NEXT: %[[RESULT:.+]] = "mhlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} @@ -290,8 +300,8 @@ func @greater_dynamic(%arg0: tensor, %arg1: tensor) -> tensor, tensor -> tensor + // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_SHAPE]] : tensor to tensor<1xindex> // CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK-DAG: 
%[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK-NEXT: "mhlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir index f84a2f28a23..876a1bf03e7 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir @@ -169,7 +169,7 @@ func @send_to_host(%arg0: tensor) { // CHECK: "mhlo.send"([[ARG0]], [[INIT_TOKEN]]) // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} // CHECK-SAME: is_host_transfer = true - // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "send_key"} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "send_key_dtoh_0"} // CHECK-SAME: (tensor, !mhlo.token) -> !mhlo.token "tf.XlaSendToHost"(%arg0) {key = "send_key"} : (tensor) -> () return @@ -186,7 +186,7 @@ func @recv_from_host() -> tensor { // CHECK: [[RECV_TUPLE:%.*]] = "mhlo.recv"([[INIT_TOKEN]]) // CHECK-SAME: channel_id = {handle = 1 : i64, type = 3 : i64} // CHECK-SAME: is_host_transfer = true - // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "recv_key"} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "recv_key_htod_0"} // CHECK-SAME: (!mhlo.token) -> tuple, !mhlo.token> @@ -407,6 +407,694 @@ func @callee2() attributes {sym_visibility = "private"} { // ----- +// Test cloned function rewrite also checks transitive function calls to +// TF/XLA communication ops. + +// CHECK: func @callee3() +func @callee3() { + // CHECK: [[CALLEE3_INIT_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: call @callee4{{.+}}([[CALLEE3_INIT_TOKEN]]) + call @callee4() : () -> () + return +} + +// CHECK: func @callee4() +func @callee4() { + // CHECK: [[CALLEE4_INIT_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: [[CALL_5:%.*]] = call @callee5([[CALLEE4_INIT_TOKEN]]) + call @callee5() : () -> () + + // CHECK: return + return +} + +// CHECK: func @callee5([[CALLEE5_ARG0:%.*]]: !mhlo.token) -> !mhlo.token +func @callee5() attributes {sym_visibility = "private"} { + // CHECK-NOT: "mhlo.create_token" + + // CHECK: [[RECV_TUPLE:%.*]] = "mhlo.recv"([[CALLEE5_ARG0]]) + // CHECK: [[RECV_VAL:%.*]] = "mhlo.get_tuple_element"([[RECV_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: [[RECV_TOKEN:%.*]] = "mhlo.get_tuple_element"([[RECV_TUPLE]]) + // CHECK-SAME: index = 1 + %0 = "tf.XlaRecvFromHost"() {key = "recv_key", shape = #tf.shape<>} : () -> tensor + + // CHECK: return [[RECV_TOKEN]] + return +} + +// CHECK: func @callee4{{.+}}([[CALLEE4_ARG0:%.*]]: !mhlo.token) -> !mhlo.token attributes {sym_visibility = "private"} +// CHECK-NOT: "mhlo.create_token" +// CHECK: [[CALL_5:%.*]] = call @callee5([[CALLEE4_ARG0]]) +// CHECK: return [[CALL_5]] + +// ----- + +// Tests `mhlo.if` with branches populated with TF/XLA communication ops. 
+ +// CHECK-LABEL: func @if_both_branches +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor) +func @if_both_branches(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[TRUE_TUPLE:%.*]] = "mhlo.tuple"([[ARG1]], [[INIT_TOKEN]]) + // CHECK: [[FALSE_TUPLE:%.*]] = "mhlo.tuple"([[ARG2]], [[INIT_TOKEN]]) + + // CHECK: [[IF_TUPLE:%.*]] = "mhlo.if"([[ARG0]], [[TRUE_TUPLE]], [[FALSE_TUPLE]]) + %0 = "mhlo.if"(%arg0, %arg1, %arg2) ( { + // CHECK: ^bb0([[TRUE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[TRUE_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[TRUE_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 1 + + // CHECK: [[TRUE_SEND_TOKEN:%.*]] = "mhlo.send"([[TRUE_REGION_ARG_VALUE]], [[TRUE_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_if_true_dtoh_0"} + + // CHECK: [[TRUE_RECV_TUPLE:%.*]] = "mhlo.recv"([[TRUE_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 2 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_if_true_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg3) {recv_key = "recv_if_true", send_key = "send_if_true", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[TRUE_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[TRUE_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[TRUE_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[TRUE_RECV_TUPLE]]) {index = 1 + // CHECK: [[TRUE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[TRUE_GET_TUPLE_ELEMENT0]], [[TRUE_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[TRUE_RETURN_TUPLE]]) + "mhlo.return"(%1) : (tensor) -> () + }, { + // CHECK: ^bb0([[FALSE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[FALSE_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[FALSE_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[FALSE_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[FALSE_REGION_ARG]]) {index = 1 + + // CHECK: [[FALSE_SEND_TOKEN:%.*]] = "mhlo.send"([[FALSE_REGION_ARG_VALUE]], [[FALSE_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 3 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_if_false_dtoh_0"} + + // CHECK: [[FALSE_RECV_TUPLE:%.*]] = "mhlo.recv"([[FALSE_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 4 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_if_false_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg3) {recv_key = "recv_if_false", send_key = "send_if_false", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[FALSE_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[FALSE_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[FALSE_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[FALSE_RECV_TUPLE]]) {index = 1 + // CHECK: [[FALSE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[FALSE_GET_TUPLE_ELEMENT0]], [[FALSE_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[FALSE_RETURN_TUPLE]]) + "mhlo.return"(%1) : (tensor) -> () + + // CHECK: (tensor, tuple, !mhlo.token>, tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor, tensor, 
tensor) -> tensor + + // CHECK: [[IF_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[IF_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: return [[IF_TUPLE_ELEMENT0]] + return %0 : tensor +} + +// ----- + +// Tests `mhlo.if` with only the `true` branch populated with TF/XLA +// communication ops. + +// CHECK-LABEL: func @if_true_branch +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor) +func @if_true_branch(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[TRUE_TUPLE:%.*]] = "mhlo.tuple"([[ARG1]], [[INIT_TOKEN]]) + // CHECK: [[FALSE_TUPLE:%.*]] = "mhlo.tuple"([[ARG2]], [[INIT_TOKEN]]) + + // CHECK: [[IF_TUPLE:%.*]] = "mhlo.if"([[ARG0]], [[TRUE_TUPLE]], [[FALSE_TUPLE]]) + %0 = "mhlo.if"(%arg0, %arg1, %arg2) ( { + // CHECK: ^bb0([[TRUE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[TRUE_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[TRUE_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 1 + + // CHECK: [[TRUE_SEND_TOKEN:%.*]] = "mhlo.send"([[TRUE_REGION_ARG_VALUE]], [[TRUE_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_if_true_dtoh_0"} + + // CHECK: [[TRUE_RECV_TUPLE:%.*]] = "mhlo.recv"([[TRUE_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 2 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_if_true_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg3) {recv_key = "recv_if_true", send_key = "send_if_true", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[TRUE_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[TRUE_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[TRUE_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[TRUE_RECV_TUPLE]]) {index = 1 + // CHECK: [[TRUE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[TRUE_GET_TUPLE_ELEMENT0]], [[TRUE_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[TRUE_RETURN_TUPLE]]) + "mhlo.return"(%1) : (tensor) -> () + }, { + // CHECK: ^bb0([[FALSE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[FALSE_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[FALSE_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[FALSE_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[FALSE_REGION_ARG]]) {index = 1 + // CHECK: [[FALSE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[FALSE_GET_TUPLE_ELEMENT0]], [[FALSE_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[FALSE_RETURN_TUPLE]]) + "mhlo.return"(%arg3) : (tensor) -> () + + // CHECK: (tensor, tuple, !mhlo.token>, tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor, tensor, tensor) -> tensor + + // CHECK: [[IF_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[IF_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: return [[IF_TUPLE_ELEMENT0]] + return %0 : tensor +} + +// ----- + +// Tests `mhlo.if` with only the `false` branch populated with TF/XLA +// communication ops. 
+ +// CHECK-LABEL: func @if_false_branch +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor) +func @if_false_branch(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[TRUE_TUPLE:%.*]] = "mhlo.tuple"([[ARG1]], [[INIT_TOKEN]]) + // CHECK: [[FALSE_TUPLE:%.*]] = "mhlo.tuple"([[ARG2]], [[INIT_TOKEN]]) + + // CHECK: [[IF_TUPLE:%.*]] = "mhlo.if"([[ARG0]], [[TRUE_TUPLE]], [[FALSE_TUPLE]]) + %0 = "mhlo.if"(%arg0, %arg1, %arg2) ( { + // CHECK: ^bb0([[TRUE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[TRUE_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[TRUE_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 1 + // CHECK: [[TRUE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[TRUE_GET_TUPLE_ELEMENT0]], [[TRUE_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[TRUE_RETURN_TUPLE]]) + "mhlo.return"(%arg3) : (tensor) -> () + }, { + // CHECK: ^bb0([[FALSE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[FALSE_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[FALSE_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[FALSE_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[FALSE_REGION_ARG]]) {index = 1 + + // CHECK: [[FALSE_SEND_TOKEN:%.*]] = "mhlo.send"([[FALSE_REGION_ARG_VALUE]], [[FALSE_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_if_false_dtoh_0"} + + // CHECK: [[FALSE_RECV_TUPLE:%.*]] = "mhlo.recv"([[FALSE_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 2 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_if_false_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg3) {recv_key = "recv_if_false", send_key = "send_if_false", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[FALSE_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[FALSE_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[FALSE_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[FALSE_RECV_TUPLE]]) {index = 1 + // CHECK: [[FALSE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[FALSE_GET_TUPLE_ELEMENT0]], [[FALSE_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[FALSE_RETURN_TUPLE]]) + "mhlo.return"(%1) : (tensor) -> () + + // CHECK: (tensor, tuple, !mhlo.token>, tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor, tensor, tensor) -> tensor + + // CHECK: [[IF_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[IF_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: return [[IF_TUPLE_ELEMENT0]] + return %0 : tensor +} + +// ----- + +// Tests `mhlo.if` with tuple arg from a `mhlo.tuple` only used by `mhlo.if` is +// replaced. 
+ +// CHECK-LABEL: func @if_replace_tuple_arg +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor) +func @if_replace_tuple_arg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK-NOT: "mhlo.tuple"([[ARG1]], [[ARG2]]) + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[IF_ARG_TUPLE:%.*]] = "mhlo.tuple"([[ARG1]], [[ARG2]], [[INIT_TOKEN]]) + %0 = "mhlo.tuple"(%arg1, %arg2) : (tensor, tensor) -> tuple, tensor> + + // CHECK: [[IF_TUPLE:%.*]] = "mhlo.if"([[ARG0]], [[IF_ARG_TUPLE]], [[IF_ARG_TUPLE]]) + %1 = "mhlo.if"(%arg0, %0, %0) ( { + ^bb0(%arg3: tuple, tensor>): + %2 = "mhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple, tensor>) -> tensor + "tf.XlaSendToHost"(%2) {key = "send_key"} : (tensor) -> () + "mhlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tuple, tensor>): + %2 = "mhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple, tensor>) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }) : (tensor, tuple, tensor>, tuple, tensor>) -> tensor + return %1 : tensor +} + +// ----- + +// Tests `mhlo.if` with tuple arg not from a `mhlo.tuple` is unpacked. + +// CHECK-LABEL: func @if_unpack_tuple_arg +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tuple, tensor>) +func @if_unpack_tuple_arg(%arg0: tensor, %arg1: tuple, tensor>) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK-DAG: [[IF_ARG_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[ARG1]]) {index = 0 + // CHECK-DAG: [[IF_ARG_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[ARG1]]) {index = 1 + // CHECK: [[IF_ARG_TUPLE:%.*]] = "mhlo.tuple"([[IF_ARG_ELEMENT0]], [[IF_ARG_ELEMENT1]], [[INIT_TOKEN]]) + + // CHECK: [[IF_TUPLE:%.*]] = "mhlo.if"([[ARG0]], [[IF_ARG_TUPLE]], [[IF_ARG_TUPLE]]) + %0 = "mhlo.if"(%arg0, %arg1, %arg1) ( { + ^bb0(%arg2: tuple, tensor>): + %1 = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple, tensor>) -> tensor + "tf.XlaSendToHost"(%1) {key = "send_key"} : (tensor) -> () + "mhlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg2: tuple, tensor>): + %1 = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple, tensor>) -> tensor + "mhlo.return"(%1) : (tensor) -> () + }) : (tensor, tuple, tensor>, tuple, tensor>) -> tensor + return %0 : tensor +} + +// ----- + +// Tests `mhlo.if` tuple result is extended with a `mhlo.token`. + +// CHECK-LABEL: func @if_extend_tuple_result +func @if_extend_tuple_result(%arg0: tensor, %arg1: tuple, tensor>) -> tuple, tensor> { + // CHECK: [[IF_TUPLE:%.*]] = "mhlo.if" + %0 = "mhlo.if"(%arg0, %arg1, %arg1) ( { + ^bb0(%arg2: tuple, tensor>): + %1 = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple, tensor>) -> tensor + "tf.XlaSendToHost"(%1) {key = "send_key"} : (tensor) -> () + "mhlo.return"(%arg2) : (tuple, tensor>) -> () + }, { + ^bb0(%arg2: tuple, tensor>): + "mhlo.return"(%arg2) : (tuple, tensor>) -> () + // CHECK: (tensor, tuple, tensor, !mhlo.token>, tuple, tensor, !mhlo.token>) -> tuple, tensor, !mhlo.token> + }) : (tensor, tuple, tensor>, tuple, tensor>) -> tuple, tensor> + + // CHECK-DAG: [[IF_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[IF_TUPLE]]) {index = 0 + // CHECK-DAG: [[IF_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[IF_TUPLE]]) {index = 1 + // CHECK: [[IF_SUBTUPLE_RESULT:%.*]] = "mhlo.tuple"([[IF_TUPLE_ELEMENT0]], [[IF_TUPLE_ELEMENT1]]) + // CHECK: return [[IF_SUBTUPLE_RESULT]] + return %0 : tuple, tensor> +} + +// ----- + +// Tests nested `mhlo.if` containing TF/XLA communication ops. 
+ +// CHECK-LABEL: func @if_nested +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor) +func @if_nested(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[OUTER_IF_ARG_TUPLE:%.*]] = "mhlo.tuple"([[ARG1]], [[INIT_TOKEN]]) + + // CHECK: "mhlo.if"([[ARG0]], [[OUTER_IF_ARG_TUPLE]], [[OUTER_IF_ARG_TUPLE]]) + %0 = "mhlo.if"(%arg0, %arg1, %arg1) ( { + // CHECK-NEXT: ^bb0([[OUTER_IF_TRUE_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg2: tensor): + // CHECK-DAG: [[OUTER_IF_TRUE_ARG_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[OUTER_IF_TRUE_ARG]]) {index = 0 + // CHECK-DAG: [[OUTER_IF_TRUE_ARG_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[OUTER_IF_TRUE_ARG]]) {index = 1 + // CHECK: [[INNER_IF_ARG_TUPLE:%.*]] = "mhlo.tuple"([[OUTER_IF_TRUE_ARG_ELEMENT0]], [[OUTER_IF_TRUE_ARG_ELEMENT1]]) + + %1 = mhlo.constant dense : tensor + + // CHECK: [[INNER_IF_TUPLE:%.*]] = "mhlo.if"({{%.*}}, [[INNER_IF_ARG_TUPLE]], [[INNER_IF_ARG_TUPLE]]) + %2 = "mhlo.if"(%1, %arg2, %arg2) ( { + // CHECK-NEXT: ^bb0([[INNER_IF_TRUE_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[INNER_IF_TRUE_ARG_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[INNER_IF_TRUE_ARG]]) {index = 0 + // CHECK-DAG: [[INNER_IF_TRUE_ARG_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[INNER_IF_TRUE_ARG]]) {index = 1 + + // CHECK: [[SEND_TOKEN:%.*]] = "mhlo.send"([[INNER_IF_TRUE_ARG_ELEMENT0]], [[INNER_IF_TRUE_ARG_ELEMENT1]]) + "tf.XlaSendToHost"(%arg3) {key = "send_key"} : (tensor) -> () + + // CHECK: [[INNER_IF_TRUE_RESULT:%.*]] = "mhlo.tuple"([[INNER_IF_TRUE_ARG_ELEMENT0]], [[SEND_TOKEN]]) + // CHECK: "mhlo.return"([[INNER_IF_TRUE_RESULT]]) + "mhlo.return"(%arg3) : (tensor) -> () + + // CHECK-NEXT: }, { + }, { + + // CHECK-NEXT: ^bb0([[INNER_IF_FALSE_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[INNER_IF_FALSE_ARG_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[INNER_IF_FALSE_ARG]]) {index = 0 + // CHECK-DAG: [[INNER_IF_FALSE_ARG_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[INNER_IF_FALSE_ARG]]) {index = 1 + // CHECK: [[INNER_IF_FALSE_RESULT:%.*]] = "mhlo.tuple"([[INNER_IF_FALSE_ARG_ELEMENT0]], [[INNER_IF_FALSE_ARG_ELEMENT1]]) + // CHECK: "mhlo.return"([[INNER_IF_FALSE_RESULT]]) + "mhlo.return"(%arg3) : (tensor) -> () + // CHECK-NEXT: (tensor, tuple, !mhlo.token>, tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor, tensor, tensor) -> tensor + + // CHECK-DAG: [[INNER_IF_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[INNER_IF_TUPLE]]) {index = 1 + // CHECK: [[OUTER_IF_TRUE_RESULT:%.*]] = "mhlo.tuple"([[OUTER_IF_TRUE_ARG_ELEMENT0]], [[INNER_IF_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[OUTER_IF_TRUE_RESULT]]) + "mhlo.return"(%arg2) : (tensor) -> () + + // CHECK-NEXT: }, { + }, { + + // CHECK-NEXT: ^bb0([[OUTER_IF_FALSE_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg2: tensor): + // CHECK-DAG: [[OUTER_IF_FALSE_ARG_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[OUTER_IF_FALSE_ARG]]) {index = 0 + // CHECK-DAG: [[OUTER_IF_FALSE_ARG_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[OUTER_IF_FALSE_ARG]]) {index = 1 + // CHECK: [[OUTER_IF_FALSE_RESULT:%.*]] = "mhlo.tuple"([[OUTER_IF_FALSE_ARG_ELEMENT0]], [[OUTER_IF_FALSE_ARG_ELEMENT1]]) + // CHECK: "mhlo.return"([[OUTER_IF_FALSE_RESULT]]) + "mhlo.return"(%arg2) : (tensor) -> () + // CHECK-NEXT: (tensor, tuple, !mhlo.token>, tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +// Tests `mhlo.if` containing a function call to TF/XLA 
communication ops. + +// CHECK-LABEL: func @if_function_call +func @if_function_call(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "mhlo.if" + %0 = "mhlo.if"(%arg0, %arg1, %arg1) ( { + // CHECK: ^bb0([[TRUE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg2: tensor): + // CHECK-DAG: [[TRUE_REGION_ARG_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[TRUE_REGION_ARG_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 1 + // CHECK: [[CALL_TOKEN:%.*]] = call @callee([[TRUE_REGION_ARG_ELEMENT0]], [[TRUE_REGION_ARG_ELEMENT1]]) + call @callee(%arg2) : (tensor) -> () + + // CHECK: [[TRUE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[TRUE_REGION_ARG_ELEMENT0]], [[CALL_TOKEN]]) + // CHECK: "mhlo.return"([[TRUE_RETURN_TUPLE]]) + "mhlo.return"(%arg2) : (tensor) -> () + }, { + ^bb0(%arg2: tensor): + "mhlo.return"(%arg2) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @callee +// CHECK-SAME: ([[CALLEE_ARG0:%.*]]: tensor, [[CALLEE_ARG1:%.*]]: !mhlo.token) -> !mhlo.token +func @callee(%arg0: tensor) attributes {sym_visibility = "private"} { + // CHECK: [[SEND_TOKEN:%.*]] = "mhlo.send" + "tf.XlaSendToHost"(%arg0) {key = "send_key"} : (tensor) -> () + + // CHECK: return [[SEND_TOKEN]] + return +} + +// ----- + +// Tests `mhlo.if` containing multiple TF/XLA communication ops. + +// CHECK-LABEL: func @if_region_multiple_ops +func @if_region_multiple_ops(%arg0: tensor, %arg1: tensor) { + // CHECK: "mhlo.if" + %0 = "mhlo.if"(%arg0, %arg1, %arg1) ( { + // CHECK: ^bb0([[TRUE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg2: tensor): + // CHECK: [[TRUE_REGION_ARG_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 0 + // CHECK: [[TRUE_REGION_ARG_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 1 + + // CHECK: [[SEND0_TOKEN:%.*]] = "mhlo.send"([[TRUE_REGION_ARG_ELEMENT0]], [[TRUE_REGION_ARG_ELEMENT1]]) + "tf.XlaSendToHost"(%arg2) {key = "send_key0"} : (tensor) -> () + + // CHECK: [[SEND1_TOKEN:%.*]] = "mhlo.send"([[TRUE_REGION_ARG_ELEMENT0]], [[SEND0_TOKEN]]) + "tf.XlaSendToHost"(%arg2) {key = "send_key1"} : (tensor) -> () + + // CHECK: [[TRUE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[TRUE_REGION_ARG_ELEMENT0]], [[SEND1_TOKEN]]) + // CHECK: "mhlo.return"([[TRUE_RETURN_TUPLE]]) + "mhlo.return"(%arg2) : (tensor) -> () + }, { + ^bb0(%arg2: tensor): + "mhlo.return"(%arg2) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + return +} + +// ----- + +// Tests `mhlo.if` containing TF/XLA communication ops followed by other TF/XLA +// communication ops. + +func @if_followed_by_communication_op(%arg0: tensor, %arg1: tensor) { + // CHECK: [[IF_TUPLE:%.*]] = "mhlo.if" + %0 = "mhlo.if"(%arg0, %arg1, %arg1) ( { + ^bb0(%arg2: tensor): + "tf.XlaSendToHost"(%arg2) {key = "send_key0"} : (tensor) -> () + "mhlo.return"(%arg2) : (tensor) -> () + }, { + ^bb0(%arg2: tensor): + "mhlo.return"(%arg2) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + + // CHECK: [[IF_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[IF_TUPLE]]) {index = 1 + + // CHECK: "mhlo.send"({{.*}}, [[IF_TUPLE_ELEMENT1]]) + "tf.XlaSendToHost"(%arg1) {key = "send_key1"} : (tensor) -> () + return +} + +// ----- + +// Tests `mhlo.while` with cond and body populated with TF/XLA communication +// ops. 
+ +// CHECK-LABEL: func @while_cond_body +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @while_cond_body(%arg0: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[ARG_TUPLE:%.*]] = "mhlo.tuple"([[ARG0]], [[INIT_TOKEN]]) + + // CHECK: [[WHILE_TUPLE:%.*]] = "mhlo.while"([[ARG_TUPLE]]) + %0 = "mhlo.while"(%arg0) ( { + // CHECK: ^bb0([[COND_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg1: tensor): + // CHECK-DAG: [[COND_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[COND_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[COND_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[COND_REGION_ARG]]) {index = 1 + + // CHECK: [[COND_SEND_TOKEN:%.*]] = "mhlo.send"([[COND_REGION_ARG_VALUE]], [[COND_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_while_cond_dtoh_0"} + + // CHECK: [[COND_RECV_TUPLE:%.*]] = "mhlo.recv"([[COND_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 2 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_while_cond_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_cond", send_key = "send_while_cond", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[COND_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[COND_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[COND_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[COND_RECV_TUPLE]]) {index = 1 + + // CHECK: [[COND_COMPARE:%.*]] = "mhlo.compare"([[COND_GET_TUPLE_ELEMENT0]], [[COND_GET_TUPLE_ELEMENT0]]) + %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + + // CHECK: "mhlo.return"([[COND_COMPARE]]) + "mhlo.return"(%2) : (tensor) -> () + }, { + // CHECK: ^bb0([[BODY_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg1: tensor): + // CHECK-DAG: [[BODY_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[BODY_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[BODY_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[BODY_REGION_ARG]]) {index = 1 + + // CHECK: [[BODY_SEND_TOKEN:%.*]] = "mhlo.send"([[BODY_REGION_ARG_VALUE]], [[BODY_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 3 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_while_body_dtoh_0"} + + // CHECK: [[BODY_RECV_TUPLE:%.*]] = "mhlo.recv"([[BODY_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 4 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_while_body_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_body", send_key = "send_while_body", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[BODY_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[BODY_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[BODY_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[BODY_RECV_TUPLE]]) {index = 1 + // CHECK: [[BODY_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[BODY_GET_TUPLE_ELEMENT0]], [[BODY_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[BODY_RETURN_TUPLE]]) + "mhlo.return"(%1) : (tensor) -> () + // CHECK: (tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor) -> tensor + + // CHECK: [[WHILE_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[WHILE_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: 
return [[WHILE_TUPLE_ELEMENT0]] + return %0 : tensor +} + +// ----- + +// Tests `mhlo.while` with only the `cond` region populated with TF/XLA +// communication ops. + +// CHECK-LABEL: func @while_cond +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @while_cond(%arg0: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[ARG_TUPLE:%.*]] = "mhlo.tuple"([[ARG0]], [[INIT_TOKEN]]) + + // CHECK: [[WHILE_TUPLE:%.*]] = "mhlo.while"([[ARG_TUPLE]]) + %0 = "mhlo.while"(%arg0) ( { + // CHECK: ^bb0([[COND_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg1: tensor): + // CHECK-DAG: [[COND_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[COND_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[COND_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[COND_REGION_ARG]]) {index = 1 + + // CHECK: [[COND_SEND_TOKEN:%.*]] = "mhlo.send"([[COND_REGION_ARG_VALUE]], [[COND_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_while_cond_dtoh_0"} + + // CHECK: [[COND_RECV_TUPLE:%.*]] = "mhlo.recv"([[COND_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 2 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_while_cond_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_cond", send_key = "send_while_cond", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[COND_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[COND_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[COND_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[COND_RECV_TUPLE]]) {index = 1 + + // CHECK: [[COND_COMPARE:%.*]] = "mhlo.compare"([[COND_GET_TUPLE_ELEMENT0]], [[COND_GET_TUPLE_ELEMENT0]]) + %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + + // CHECK: "mhlo.return"([[COND_COMPARE]]) + "mhlo.return"(%2) : (tensor) -> () + }, { + // CHECK: ^bb0([[BODY_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg1: tensor): + // CHECK-DAG: [[BODY_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[BODY_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[BODY_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[BODY_REGION_ARG]]) {index = 1 + // CHECK: [[BODY_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[BODY_GET_TUPLE_ELEMENT0]], [[BODY_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[BODY_RETURN_TUPLE]]) + "mhlo.return"(%arg1) : (tensor) -> () + // CHECK: (tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor) -> tensor + + // CHECK: [[WHILE_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[WHILE_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: return [[WHILE_TUPLE_ELEMENT0]] + return %0 : tensor +} + +// ----- + +// Tests `mhlo.while` with only the `body` region populated with TF/XLA +// communication ops. 
+ +// CHECK-LABEL: func @while_body +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @while_body(%arg0: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[ARG_TUPLE:%.*]] = "mhlo.tuple"([[ARG0]], [[INIT_TOKEN]]) + + // CHECK: [[WHILE_TUPLE:%.*]] = "mhlo.while"([[ARG_TUPLE]]) + %0 = "mhlo.while"(%arg0) ( { + // CHECK: ^bb0([[COND_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg1: tensor): + // CHECK-DAG: [[COND_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[COND_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[COND_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[COND_REGION_ARG]]) {index = 1 + + // CHECK: [[COND_COMPARE:%.*]] = "mhlo.compare"([[COND_GET_TUPLE_ELEMENT0]], [[COND_GET_TUPLE_ELEMENT0]]) + %2 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + + // CHECK: "mhlo.return"([[COND_COMPARE]]) + "mhlo.return"(%2) : (tensor) -> () + }, { + // CHECK: ^bb0([[BODY_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg1: tensor): + // CHECK-DAG: [[BODY_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[BODY_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[BODY_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[BODY_REGION_ARG]]) {index = 1 + + // CHECK: [[BODY_SEND_TOKEN:%.*]] = "mhlo.send"([[BODY_REGION_ARG_VALUE]], [[BODY_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_while_body_dtoh_0"} + + // CHECK: [[BODY_RECV_TUPLE:%.*]] = "mhlo.recv"([[BODY_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 2 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_while_body_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_body", send_key = "send_while_body", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[BODY_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[BODY_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[BODY_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[BODY_RECV_TUPLE]]) {index = 1 + // CHECK: [[BODY_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[BODY_GET_TUPLE_ELEMENT0]], [[BODY_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[BODY_RETURN_TUPLE]]) + "mhlo.return"(%1) : (tensor) -> () + // CHECK: (tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor) -> tensor + + // CHECK: [[WHILE_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[WHILE_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: return [[WHILE_TUPLE_ELEMENT0]] + return %0 : tensor +} + +// ----- + +// Tests `mhlo.while` containing TF/XLA communication ops followed by other +// TF/XLA communication ops. + +func @while_followed_by_communication_op(%arg0: tensor) { + // CHECK: [[WHILE_TUPLE:%.*]] = "mhlo.while" + %0 = "mhlo.while"(%arg0) ( { + ^bb0(%arg1: tensor): + "tf.XlaSendToHost"(%arg1) {key = "send_key0"} : (tensor) -> () + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + + // CHECK: [[WHILE_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[WHILE_TUPLE]]) {index = 1 + + // CHECK: "mhlo.send"({{.*}}, [[WHILE_TUPLE_ELEMENT1]]) + "tf.XlaSendToHost"(%arg0) {key = "send_key1"} : (tensor) -> () + return +} + +// ----- + +// Tests unsupported parent of TF/XLA communication op. 
+
+func @unsupported_ancestor(%arg0: tensor, %arg1: tensor) {
+  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
+  ^bb0(%arg2: tensor, %arg3: tensor):
+    %1 = mhlo.add %arg2, %arg3 : tensor
+    // expected-error@+1 {{expects ancestor(s) to be of ['mhlo.if', 'func']}}
+    "tf._XlaHostComputeMlir"() {recv_key = "host_compute_channel_recv", send_key = "host_compute_channel_send", tpu_core = 0 : i64} : () -> ()
+    "mhlo.return"(%1) : (tensor) -> ()
+  }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor, tensor) -> tensor
+  return
+}
+
+// -----
+
+// Tests transitive unsupported parent of TF/XLA communication op.
+
+func @unsupported_ancestor(%arg0: tensor, %arg1: tensor) {
+  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
+  ^bb0(%arg2: tensor, %arg3: tensor):
+    %1 = mhlo.add %arg2, %arg3 : tensor
+    // expected-error@+1 {{expects ancestor(s) to be of ['mhlo.if', 'func']}}
+    call @callee() : () -> ()
+    "mhlo.return"(%1) : (tensor) -> ()
+  }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor, tensor) -> tensor
+  return
+}
+
+func @callee() attributes {sym_visibility = "private"} {
+  "tf._XlaHostComputeMlir"() {recv_key = "host_compute_channel_recv", send_key = "host_compute_channel_send", tpu_core = 0 : i64} : () -> ()
+  return
+}
+
+// -----
+
+// Tests unsupported `mhlo.if` with region of more than one block and contains a
+// TF/XLA communication op.
+
+func @if_multiple_blocks(%arg0: tensor, %arg1: tensor) {
+  %0 = "mhlo.if"(%arg0, %arg1, %arg1) ( {
+  ^bb0(%arg2: tensor):
+    br ^bb1(%arg2 : tensor)
+  ^bb1(%arg3: tensor):
+    // expected-error@+1 {{expects single block region ancestor(s)}}
+    "tf.XlaSendToHost"(%arg3) {key = "send_key0"} : (tensor) -> ()
+    "mhlo.return"(%arg3) : (tensor) -> ()
+  }, {
+  ^bb0(%arg2: tensor):
+    "mhlo.return"(%arg2) : (tensor) -> ()
+  }) : (tensor, tensor, tensor) -> tensor
+  return
+}
+
+// -----
+
 // Tests function with more than one block that is to be rewritten emits an
 // error instead.
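Editor's note (not part of the patch): the legalize-tf-communication tests above all exercise the same rewrite. The pass threads an `!mhlo.token` through functions and through `mhlo.if`/`mhlo.while` regions, lowers `tf.XlaSendToHost`, `tf.XlaRecvFromHost`, and `tf._XlaHostComputeMlir` to `mhlo.send`/`mhlo.recv`, and suffixes the rendezvous keys with `_dtoh_0` (device-to-host) or `_htod_0` (host-to-device). A minimal sketch of the shape of that rewrite, inferred from the CHECK patterns above; the element type and channel handle values are illustrative, not taken from the patch:

  // Input: a host transfer expressed with a TF op.
  func @send_to_host(%arg0: tensor<i32>) {
    "tf.XlaSendToHost"(%arg0) {key = "send_key"} : (tensor<i32>) -> ()
    return
  }

  // Roughly what the pass produces: a token is created and threaded through the
  // newly introduced mhlo.send, and the key gains a `_dtoh_0` suffix.
  func @send_to_host(%arg0: tensor<i32>) {
    %token = "mhlo.create_token"() : () -> !mhlo.token
    %done = "mhlo.send"(%arg0, %token) {
      channel_id = {handle = 1 : i64, type = 2 : i64},
      is_host_transfer = true,
      mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32",
                                  _xla_host_transfer_rendezvous = "send_key_dtoh_0"}
    } : (tensor<i32>, !mhlo.token) -> !mhlo.token
    return
  }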
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir
index 5a9089756a9..93eac3821b2 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir
@@ -44,7 +44,7 @@ attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
 // CHECK-LABEL: func @case
 // CHECK-SAME: %[[BRANCH_INDEX:.*]]: tensor, %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -> (tensor, tensor)
 func @case(%index: tensor, %arg0: tensor, %arg1: tensor) -> (tensor, tensor) {
-  %0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor]} : (tensor, tensor, tensor) -> (tensor, tensor)
+  %0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor], is_stateless = true} : (tensor, tensor, tensor) -> (tensor, tensor)
   // CHECK: %[[TUPLE_INPUT:.*]] = "mhlo.tuple"(%[[ARG0]], %[[ARG1]]) : (tensor, tensor) -> tuple, tensor>
   // CHECK: %[[CASE:.*]]:2 = "mhlo.case"(%[[BRANCH_INDEX]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]]) ( {
   // CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple, tensor>):
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-include-tf2xla-fallback.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-include-tf2xla-fallback.mlir
new file mode 100644
index 00000000000..9f72820d15b
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-include-tf2xla-fallback.mlir
@@ -0,0 +1,50 @@
+// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=false" -verify-diagnostics %s | FileCheck --check-prefix NO_FALLBACK %s
+// RUN: tf-opt "-xla-legalize-tf=use-tf2xla-fallback=true device-type=XLA_CPU_JIT" -verify-diagnostics %s | FileCheck --check-prefix SUPPORTED_FALLBACK_DEVICE %s
+// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true" %s | FileCheck --check-prefix UNSPECIFIED_FALLBACK_DEVICE %s
+// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true device-type=INVALID_DEVICE_TYPE" %s | FileCheck --check-prefix UNSUPPORTED_FALLBACK_DEVICE %s
+
+// We run this test four times:
+// 1) Legalize without using TF2XLA fallback (ops cannot be legalized).
+// 2) Use fallback with a device that supports all ops (ops can be legalized).
+// 3) Use fallback with unspecified device (ops cannot be legalized).
+// 4) Use fallback with specified but unsupported device (ops cannot be legalized).
+//
+// Note: For 3) and 4) we do not use `-verify-diagnostics` because these cases
+// produce remarks that don't occur for 1) and 2) and there is no way to check
+// the remarks only for 3) and 4) (except using two files).
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + +// CHECK-LABEL: non_max_suppression_v4 +func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor) -> tensor<2xi32> { + %max_size = mhlo.constant dense<2> : tensor + // NO_FALLBACK: tf.NonMaxSuppressionV4 + // SUPPORTED_FALLBACK_DEVICE-NOT: tf.NonMaxSuppressionV4 + // UNSPECIFIED_FALLBACK_DEVICE: tf.NonMaxSuppressionV4 + // UNSUPPORTED_FALLBACK_DEVICE: tf.NonMaxSuppressionV4 + %0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %max_size, %arg2, %arg3) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor) -> (tensor<2xi32>, tensor) + return %0#0 : tensor<2xi32> +} + +// CHECK-LABEL: mirror_pad +func @mirror_pad(%arg0: tensor<2x3xcomplex>) -> tensor<4x7xcomplex> { + %0 = mhlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32> + // NO_FALLBACK: tf.MirrorPad + // SUPPORTED_FALLBACK_DEVICE-NOT: tf.MirrorPad + // UNSPECIFIED_FALLBACK_DEVICE: tf.MirrorPad + // UNSUPPORTED_FALLBACK_DEVICE: tf.MirrorPad + %1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex>, tensor<2x2xi32>) -> tensor<4x7xcomplex> + return %1 : tensor<4x7xcomplex> +} + +// CHECK-LABEL: atan2 +func @atan2(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> tensor<4x4x4xf32> { + // NO_FALLBACK: tf.Atan2 + // SUPPORTED_FALLBACK_DEVICE-NOT: tf.Atan2 + // UNSPECIFIED_FALLBACK_DEVICE: tf.Atan2 + // UNSUPPORTED_FALLBACK_DEVICE: tf.Atan2 + %0 = "tf.Atan2"(%arg0, %arg1) : (tensor<4x1xf32>, tensor<4x1x4xf32>) -> tensor<4x4x4xf32> + return %0: tensor<4x4x4xf32> +} + +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index cd351447303..8c8d99940de 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -220,13 +220,6 @@ func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tenso return %0 : tensor<3x3xf32> } -// CHECK-LABEL: fft -func @fft(%arg0: tensor<3x5x8xcomplex>) -> tensor<3x5x8xcomplex> { - // CHECK: "mhlo.fft"(%arg0) - %0 = "tf.FFT"(%arg0) : (tensor<3x5x8xcomplex>) -> tensor<3x5x8xcomplex> - return %0 : tensor<3x5x8xcomplex> -} - // CHECK-LABEL: reverse_sequence func @reverse_sequence(%arg0: tensor<4x2x3x1x1xi32>, %arg1: tensor<3xi32>) -> tensor<4x2x3x1x1xi32> { // CHECK-NOT: tf.ReverseSequence @@ -265,6 +258,47 @@ func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2 return %0#0 : tensor<2xi32> } +// CHECK-LABEL: bessel_i0e +func @bessel_i0e(%arg0: tensor<3xf16>, %arg1: tensor<3xf32>, %arg2: tensor<3xf64>) -> (tensor<3xf16>, tensor<3xf32>, tensor<3xf64>) { + // CHECK-NOT: tf.BesselI0e + %0 = "tf.BesselI0e"(%arg0) : (tensor<3xf16>) -> (tensor<3xf16>) + %1 = "tf.BesselI0e"(%arg1) : (tensor<3xf32>) -> (tensor<3xf32>) + %2 = "tf.BesselI0e"(%arg2) : (tensor<3xf64>) -> (tensor<3xf64>) + return %0, %1, %2 : tensor<3xf16>, tensor<3xf32>, tensor<3xf64> +} + +// CHECK-LABEL: bessel_i1e +func @bessel_i1e(%arg0: tensor<3xf16>, %arg1: tensor<3xf32>, %arg2: tensor<3xf64>) -> (tensor<3xf16>, tensor<3xf32>, tensor<3xf64>) { + // CHECK-NOT: tf.BesselI1e + %0 = "tf.BesselI1e"(%arg0) : (tensor<3xf16>) -> (tensor<3xf16>) + %1 = "tf.BesselI1e"(%arg1) : (tensor<3xf32>) -> (tensor<3xf32>) + %2 = "tf.BesselI1e"(%arg2) : (tensor<3xf64>) -> (tensor<3xf64>) + return %0, %1, %2 : tensor<3xf16>, 
tensor<3xf32>, tensor<3xf64> +} + +// CHECK-LABEL: diag +func @diag(%arg0: tensor<2xf32>) -> tensor<2x2xf32> { + // CHECK-NOT: tf.Diag + %0 = "tf.Diag"(%arg0) : (tensor<2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// CHECK-LABEL: random_uniform_int +func @random_uniform_int(%arg0: tensor, %arg1: tensor) -> tensor<1000xi32> { + %0 = "tf.Const"() {value = dense<1000> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NOT: tf.RandomUniformInt + %1 = "tf.RandomUniformInt"(%0, %arg0, %arg1) {seed = 0 : i64, seed2 = 0 : i64} : (tensor<1xi32>, tensor, tensor) -> tensor<1000xi32> + return %1 : tensor<1000xi32> +} + +// CHECK-LABEL: multinomial +func @multinomial(%arg0: tensor<2x4xf32>, %seed: tensor, %seed2: tensor) -> tensor<2x10xi32> { + // CHECK-NOT: tf.Multinomial + %samples = "tf.Const"() { value = dense<10> : tensor } : () -> tensor + %1 = "tf.Multinomial"(%arg0, %samples) {seed = 0, seed2 = 0}: (tensor<2x4xf32>, tensor) -> tensor<2x10xi32> + return %1 : tensor<2x10xi32> +} + // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // available but doesn't support this instance. } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 3b4efc388eb..4c5ce2f74d9 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -1,5 +1,5 @@ // RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FILECHECK_OPTS="" FileCheck %s -// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -verify-diagnostics %s | FileCheck %s --check-prefix CHLO +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -verify-diagnostics %s | FileCheck %s --check-prefix CHLO --dump-input-filter=all // This test runs twice: // 1. Through FILECHECK_OPTS="" FileCheck with chlo legalization disabled since verifying // that the chlo ops emit produces more useful tests. @@ -439,6 +439,17 @@ func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tens // Bias op legalizations. //===----------------------------------------------------------------------===// +// CHECK-LABEL: func @biasAdd_default +func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { + // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] + %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + return %0 : tensor<1x32x10x32xi32> +} + // CHECK-LABEL: func @biasAdd_NHWC func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 @@ -1269,6 +1280,15 @@ func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x return %0 : tensor<2x8x4x7x7xf32> } +// CHECK-LABEL: maxpool_explicit_padding +func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> { + // CHECK: tf.MaxPool + // TODO(b/165938852): need to support explicit padding in max_pool. 
+ + %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> + return %0 : tensor<2x3x5x7xi32> +} + //===----------------------------------------------------------------------===// // MaxPoolGrad op legalizations. //===----------------------------------------------------------------------===// @@ -1499,6 +1519,35 @@ func @stateful_pcall_multi_in_out(%arg0: tensor, %arg1: tensor) -> (te return %arg1, %arg0 : tensor, tensor } +//===----------------------------------------------------------------------===// +// Elu op legalizations. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @elu +func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %arg0, %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "GT"} + // CHECK-DAG: %[[EXP:.*]] = "mhlo.exponential_minus_one"(%arg0) + // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %arg0, %[[EXP]]) + // CHECK: return %[[RESULT]] + %0 = "tf.Elu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + return %0: tensor<1xf32> +} + +// CHECK-LABEL: func @elu_grad +// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor) +func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tensor<4x8xf32> { + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "GT"} + // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: %[[MULGRAD:.*]] = "mhlo.multiply"(%[[GRADIENTS]], %[[ADD1]]) + // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[MULGRAD]]) + // CHECK: return %[[RESULT]] + %2 = "tf.EluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> + return %2 : tensor<4x8xf32> +} + //===----------------------------------------------------------------------===// // Relu op legalizations. //===----------------------------------------------------------------------===// @@ -1726,6 +1775,20 @@ func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // Fast Fourier Transform op legalization. 
//===----------------------------------------------------------------------===// +// CHECK-LABEL: func @fft_1D +func @fft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { + // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "FFT"} : (tensor<8xcomplex> + %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex>) -> tensor<8xcomplex> + return %0 : tensor<8xcomplex> +} + +// CHECK-LABEL: func @ifft_1D +func @ifft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { + // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "IFFT"} : (tensor<8xcomplex> + %0 = "tf.IFFT"(%arg0) : (tensor<8xcomplex>) -> tensor<8xcomplex> + return %0 : tensor<8xcomplex> +} + // CHECK-LABEL: func @rfft_1D func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex> { %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) @@ -1734,6 +1797,48 @@ func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex> { return %0 : tensor<8xcomplex> } +// CHECK-LABEL: func @rfft_1D_padded +func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<8xcomplex> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: %[[PADDED:.*]] = "mhlo.pad"(%arg0, %2) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<7xf32>, tensor) -> tensor<8xf32> + // CHECK: "mhlo.fft"(%[[PADDED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<8xf32> + %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<8xcomplex> + return %0 : tensor<8xcomplex> +} + +// CHECK-LABEL: func @rfft_1D_sliced +func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x8xcomplex> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[2, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x9xf32>) -> tensor<2x8xf32> + // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<2x8xf32> + %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> tensor<2x8xcomplex> + return %0 : tensor<2x8xcomplex> +} + +// CHECK-LABEL: func @irfft_1D +func @irfft_1D(%arg0: tensor<8xcomplex>) -> tensor<5xf32> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<5> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<8xcomplex>) -> tensor<5xcomplex> + // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<5> : tensor<1xi64>, fft_type = "IRFFT"} : (tensor<5xcomplex> + %0 = "tf.IRFFT"(%arg0, %fftlength) : (tensor<8xcomplex>, tensor<1xi32>) -> tensor<5xf32> + return %0 : tensor<5xf32> +} + +// CHECK-LABEL: fft_1D_dynamic +func @fft_1D_dynamic(%arg0: tensor>) -> tensor<8xcomplex> { + // CHECK: "tf.FFT" + %0 = "tf.FFT"(%arg0) : (tensor>) -> tensor<8xcomplex> + return %0 : tensor<8xcomplex> +} + +// CHECK-LABEL: rfft_1D_dynamic +func @rfft_1D_dynamic(%arg0: tensor) -> tensor<8xcomplex> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: "tf.RFFT" + %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor, tensor<1xi32>) -> tensor<8xcomplex> + return %0 : tensor<8xcomplex> +} + 
//===----------------------------------------------------------------------===// // Shape op legalization. //===----------------------------------------------------------------------===// @@ -1852,16 +1957,16 @@ func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: @acos // CHLO-LABEL: @acos func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "chlo.acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: chlo.acos %arg0 : tensor<2xf32> // CHLO: %[[VAL_1:.*]] = "mhlo.compare"({{.*}}) {comparison_direction = "NE"} -// CHLO: %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> -// CHLO: %[[VAL_4:.*]] = mhlo.constant dense<1.000000e+00> // CHLO: %[[VAL_5:.*]] = mhlo.multiply %arg0, %arg0 +// CHLO: %[[VAL_4:.*]] = mhlo.constant dense<1.000000e+00> // CHLO: %[[VAL_6:.*]] = mhlo.subtract %[[VAL_4]], %[[VAL_5]] // CHLO: %[[VAL_7:.*]] = "mhlo.sqrt"(%[[VAL_6]]) // CHLO: %[[VAL_8:.*]] = mhlo.constant dense<1.000000e+00> // CHLO: %[[VAL_9:.*]] = mhlo.add %[[VAL_8]], %arg0 // CHLO: %[[VAL_10:.*]] = mhlo.atan2 %[[VAL_7]], %[[VAL_9]] +// CHLO: %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> // CHLO: %[[VAL_11:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_10]] // CHLO: %[[VAL_12:.*]] = mhlo.constant dense<3.14159274> // CHLO: %[[VAL_13:.*]] = "mhlo.select"(%[[VAL_1]], %[[VAL_11]], %[[VAL_12]]) @@ -1870,6 +1975,44 @@ func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> { return %0 : tensor<2xf32> } +// CHECK-LABEL: @acos_dynamic +// CHLO-LABEL: @acos_dynamic +func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: chlo.acos %arg0 : tensor<*xf32> + // `tf.Acos` is lowered to `chlo.constant_like` operations which can only be + // lowered further on ranked tensors. Unranked CHLO must be transformed to + // ranked code before further lowering. 
+ // CHLO: "tf.Acos" + %0 = "tf.Acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// CHECK-LABEL: @tan +// CHECK-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32> +// CHLO-LABEL: @tan +// CHLO-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32> +func @tan(%arg : tensor<2xf32>) -> tensor<2xf32> { + // CHECK: chlo.tan %[[ARG]] : tensor<2xf32> + // CHLO: %[[SINE:.*]] = "mhlo.sine"(%[[ARG]]) + // CHLO %[[COSINE:.*]] = "mhlo.cosine"(%[[ARG]]) + // CHLO %[[RESULT:.*]] = "mhlo.divide"(%[[SINE]], %[[COSINE]]) + %result = "tf.Tan"(%arg) : (tensor<2xf32>) -> tensor<2xf32> + return %result : tensor<2xf32> +} + +// CHECK-LABEL: @tan_unranked +// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32> +// CHLO-LABEL: @tan_unranked +// CHLO-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32> +func @tan_unranked(%arg : tensor<*xf32>) -> tensor<*xf32> { + // CHECK: chlo.tan %[[ARG]] : tensor<*xf32> + // CHLO: %[[SINE:.*]] = "mhlo.sine"(%[[ARG]]) + // CHLO %[[COSINE:.*]] = "mhlo.cosine"(%[[ARG]]) + // CHLO %[[RESULT:.*]] = "mhlo.divide"(%[[SINE]], %[[COSINE]]) + %result = "tf.Tan"(%arg) : (tensor<*xf32>) -> tensor<*xf32> + return %result : tensor<*xf32> +} + // CHECK-LABEL: func @cast_dynamic_i2f func @cast_dynamic_i2f(%arg0: tensor) -> tensor { // CHECK: "mhlo.convert"(%arg0) : (tensor) -> tensor @@ -2266,10 +2409,10 @@ func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2x1xf32> { } // CHECK-LABEL: reshape_dynamic -func @reshape_dynamic(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor<1x1xf32> { - // CHECK: "mhlo.reshape" - %0 = "tf.Reshape"(%arg0, %arg1) : (tensor, tensor<2xi32>) -> tensor<1x1xf32> - return %0 : tensor<1x1xf32> +func @reshape_dynamic(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor { + // CHECK: "mhlo.dynamic_reshape" + %0 = "tf.Reshape"(%arg0, %arg1) : (tensor, tensor<2xi32>) -> tensor + return %0 : tensor } // CHECK-LABEL: reshape_unranked @@ -2300,6 +2443,25 @@ func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor) -> tensor<1x2xf32> { return %0 : tensor<1x2xf32> } +// CHECK-LABEL: expand_dims_dynamic +func @expand_dims_dynamic(%arg0: tensor) -> tensor { + %axis = "tf.Const"() {value = dense<1> : tensor} : () -> (tensor) + + // CHECK-DAG: [[SHAPEOF:%.+]] = shape.shape_of %arg0 + // CHECK-DAG: [[CST0:%.+]] = constant 0 + // CHECK-DAG: [[CST1:%.+]] = constant 1 + // CHECK-DAG: [[GETEXTENT0:%.+]] = shape.get_extent [[SHAPEOF]], [[CST0]] + // CHECK-DAG: [[CST1_0:%.+]] = constant 1 + // CHECK-DAG: [[GETEXTENT1:%.+]] = shape.get_extent [[SHAPEOF]], [[CST1_0]] + // CHECK-DAG: [[FROMEXTENTS:%.+]] = shape.from_extents [[GETEXTENT0]], [[CST1]], [[GETEXTENT1]] + // CHECK-DAG: [[TOEXTENTS:%.+]] = shape.to_extent_tensor [[FROMEXTENTS]] + // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.dynamic_reshape"(%arg0, [[TOEXTENTS]]) + %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor, tensor) -> tensor + + // CHECK: return [[RESHAPE]] + return %0 : tensor +} + // CHECK-LABEL: func @sign // CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32> func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { @@ -3463,6 +3625,20 @@ func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tenso return %result : tensor<2x8x8x8x1xf32> } +// CHECK-LABEL: @collective_permute +func @collective_permute(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { + %source_target_pairs = "tf.Const" () { + value = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi32> + } : () -> tensor<3x2xi32> + + // CHECK: "mhlo.collective_permute" + // CHECK-SAME: source_target_pairs = dense<{{\[}}[0, 
1], [1, 2], [2, 3]]> : tensor<3x2xi64> + %0 = "tf.CollectivePermute"(%arg0, %source_target_pairs) { + } : (tensor<128x32xf32>, tensor<3x2xi32>) -> tensor<128x32xf32> + + return %0 : tensor<128x32xf32> +} + // CHECK-LABEL: @cross_replica_sum func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { %replica_groups = "tf.Const" () { @@ -3483,8 +3659,9 @@ func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { func @size_scalar_i32(%input: tensor) -> (tensor) { // CHECK: %[[CONST:.*]] = mhlo.constant dense<1> // CHECK-SAME: tensor + // CHECK: %[[CAST:.*]] = tensor_cast %[[CONST]] : tensor to tensor %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor) -> tensor - // CHECK: return %[[CONST]] + // CHECK: return %[[CAST]] return %size : tensor } @@ -3492,8 +3669,9 @@ func @size_scalar_i32(%input: tensor) -> (tensor) { func @size_scalar_i64(%input: tensor) -> (tensor) { // CHECK: %[[CONST:.*]] = mhlo.constant dense<1> // CHECK-SAME: tensor + // CHECK: %[[CAST:.*]] = tensor_cast %[[CONST]] : tensor to tensor %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT64"} : (tensor) -> tensor - // CHECK: return %[[CONST]] + // CHECK: return %[[CAST]] return %size : tensor } @@ -3754,7 +3932,7 @@ func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor, %segment_ids : tensor) -> (tensor<4x?xf32>) { %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: mhlo.constant dense<0x7F800000> : tensor + // CHECK: mhlo.constant dense<3.40282347E+38> : tensor // CHECK: mhlo.scatter // CHECK: mhlo.minimum %0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) @@ -3764,7 +3942,7 @@ func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor, %segment_ids : tensor) -> (tensor<4x?xf32>) { %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: mhlo.constant dense<0xFF800000> : tensor + // CHECK: mhlo.constant dense<-3.40282347E+38> : tensor // CHECK: mhlo.scatter // CHECK: mhlo.maximum %0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) @@ -4581,21 +4759,65 @@ func @cumsum_static(%arg0: tensor<4xf32>) -> tensor<4xf32> { } // CHECK-LABEL: func @cumsum_exclusive +// CHECK-SAME: [[X:%.*]]: tensor<4xf32> func @cumsum_exclusive(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: "tf.Cumsum" + // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor + // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[X]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( { + // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): + // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor + // CHECK: "mhlo.return"([[SUM]]) : (tensor) -> () + // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[PAD]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: return [[CONVERT_REDUCE]] %0 = 
"tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> return %1 : tensor<4xf32> } // CHECK-LABEL: func @cumsum_reverse +// CHECK-SAME: [[X:%.*]]: tensor<4xf32> func @cumsum_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: "tf.Cumsum" + // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor + // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[REVERSE1]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( { + // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): + // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor + // CHECK: "mhlo.return"([[SUM]]) : (tensor) -> () + // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[REDUCE]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: return [[REVERSE_BACK]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = true} : (tensor<4xf32>, tensor) -> tensor<4xf32> return %1 : tensor<4xf32> } +// CHECK-LABEL: func @cumsum_exclusive_reverse +// CHECK-SAME: [[X:%.*]]: tensor<4xf32> +func @cumsum_exclusive_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor + // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[REVERSE1]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( { + // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): + // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor + // CHECK: "mhlo.return"([[SUM]]) : (tensor) -> () + // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[PAD]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: return [[REVERSE_BACK]] + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor + %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = true} : (tensor<4xf32>, tensor) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + // CHECK-LABEL: func @cumsum_dynamic func @cumsum_dynamic(%arg0: tensor, 
%arg1: tensor) -> tensor { // CHECK: "tf.Cumsum" @@ -4603,6 +4825,24 @@ func @cumsum_dynamic(%arg0: tensor, %arg1: tensor) -> tensor return %0 : tensor } +//===----------------------------------------------------------------------===// +// Cumprod op legalizations. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @cumprod +func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ( { + // CHECK: mhlo.mul + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor + %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + +//===----------------------------------------------------------------------===// +// Qr op legalization +//===----------------------------------------------------------------------===// + // CHECK: func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) { // The tf.Qr lowering is a full algorithm that is not effective to verify with @@ -4697,3 +4937,37 @@ func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> { // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf64> return %0 : tensor<8x16xf64> } + +// CHECK-LABEL: @xla_gather +func @xla_gather(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x1x300xf32> { + %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi64> } : () -> tensor<3xi64> + + // CHECK: "mhlo.gather" + // CHECK-SAME: dimension_numbers = + // CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64> + // CHECK-SAME: index_vector_dim = 1 : i64 + // CHECK-SAME: offset_dims = dense<1> : tensor<1xi64> + // CHECK-SAME: start_index_map = dense<0> : tensor<1xi64> + // CHECK-SAME: indices_are_sorted = true + // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64> + + %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<10x1x300xf32> + return %0 : tensor<10x1x300xf32> +} + +// CHECK-LABEL: @xla_gather_i32 +func @xla_gather_i32(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x1x300xf32> { + %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi32> } : () -> tensor<3xi32> + + // CHECK: "mhlo.gather" + // CHECK-SAME: dimension_numbers = + // CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64> + // CHECK-SAME: index_vector_dim = 1 : i64 + // CHECK-SAME: offset_dims = dense<1> : tensor<1xi64> + // CHECK-SAME: start_index_map = dense<0> : tensor<1xi64> + // CHECK-SAME: indices_are_sorted = true + // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64> + + %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi32>) -> tensor<10x1x300xf32> + return %0 : tensor<10x1x300xf32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc index 1a3f0c16247..de8d6fc697b 100644 --- a/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc +++ b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc @@ 
-42,13 +42,13 @@ class XlaBuilderTest : public ::testing::Test { protected: XlaBuilderTest() : name_(SetupTest()), - context_(), module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_))), builder_(&module_->getBodyRegion()), - xla_builder_(name_, builder_, module_->getLoc()) {} + xla_builder_(name_, builder_, module_->getLoc()) { + context_.loadDialect(); + } string SetupTest() { - mlir::registerDialect(); return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 9929bd85b43..ff1bcadda7b 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -362,7 +362,9 @@ func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3xf32> // CHECK: [[VAL_1:%.*]] = f32[2,3] parameter(0) // CHECK: [[VAL_2:%.*]] = f32[5,5] parameter(1) // CHECK: ROOT -// CHECK-SAME: f32[1,2,3] custom-call(f32[2,3] [[VAL_1]], f32[5,5] [[VAL_2]]), custom_call_target="foo", backend_config="bar" +// CHECK-SAME: f32[1,2,3] custom-call(f32[2,3] [[VAL_1]], f32[5,5] [[VAL_2]]) +// CHECK-SAME: custom_call_target="foo" +// CHECK-SAME: backend_config="bar" // ----- @@ -1087,3 +1089,15 @@ func @main(%arg: tensor<3x4xf32>, %token: !mhlo.token) -> !mhlo.token { } // CHECK-NOT: frontend_attributes + +// ----- + +// Checks exporting rng-bit-generator. + +// CHECK: HloModule +func @main(%arg: tensor<3xui64>) -> tuple, tensor<2x2xui32>> { +// CHECK: %[[ARG0:.*]] = u64[3] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %[[ARG0]]), algorithm=rng_philox + %0 = "mhlo.rng_bit_generator"(%arg) {rng_algorithm = 2 : i32} : (tensor<3xui64>) -> tuple, tensor<2x2xui32>> + return %0 : tuple, tensor<2x2xui32>> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir index 97c53cb5f9f..0c2aee5a2fd 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir @@ -2,6 +2,6 @@ // CHECK: Opaque elements attr not supported func @main() { - %0 = "tf.Const"() {value = opaque<"tf", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32> + %0 = "mhlo.constant"() {value = opaque<"mhlo", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32> return } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index d89b1fa44e1..4d4e0213da8 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -1005,3 +1005,12 @@ add { // CHECK: "mhlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16> ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1) } + +// CHECK-LABEL: func @rngbitgen +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xui64>) +%rngbitgen (Arg_0.1: u64[3]) -> (u64[3], u32[2,2]) { + %Arg_0.1 = u64[3] parameter(0) + // CHECK: "mhlo.rng_bit_generator"(%[[ARG0]]) {rng_algorithm = 2 : i32} : (tensor<3xui64>) -> tuple, tensor<2x2xui32>> + ROOT %rng-bit-generator.2 = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %Arg_0.1), algorithm=rng_philox +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/location_to_op_metadata.mlir b/tensorflow/compiler/mlir/xla/tests/translate/location_to_op_metadata.mlir new file mode 100644 index 
00000000000..2182ce6106d --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/location_to_op_metadata.mlir @@ -0,0 +1,43 @@ +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input=always + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc(unknown) + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-NOT: metadata + +// ----- + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("AfterAll") + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-SAME: metadata={op_name="AfterAll"} + +// ----- + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("name@function") + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-SAME: metadata={op_name="name"} + +// ----- + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("file_name":2:8) + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-SAME: metadata={source_file="file_name" source_line=2} diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 0b420fff785..c990473a6d4 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -15,6 +15,7 @@ limitations under the License. // This file implements logic for lowering TensorFlow dialect to XLA dialect. +#include #include #include #include @@ -42,6 +43,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" @@ -50,6 +52,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" +#include "tensorflow/compiler/mlir/xla/attribute_importer.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h" #include "tensorflow/compiler/xla/client/padding.h" @@ -57,7 +60,7 @@ limitations under the License. 
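The location_to_op_metadata.mlir tests above pin down how MLIR locations are expected to surface as HLO OpMetadata: a name location becomes op_name (anything after '@' is dropped), a file:line:col location becomes source_file/source_line (the column is dropped), and an unknown location produces no metadata. The standalone C++ sketch below is only a model of that mapping as encoded by the FileCheck lines, not the exporter's actual code; the struct and function names are mine.

#include <string>

// Schematic stand-ins for the two location kinds exercised by the tests.
struct NameLoc { std::string name; };                  // e.g. "AfterAll" or "name@function"
struct FileLineColLoc { std::string file; int line; }; // e.g. "file_name":2:8

struct OpMetadata {
  std::string op_name;      // empty if not set
  std::string source_file;  // empty if not set
  int source_line = 0;
};

// loc("name@function") keeps only the part before '@'; loc("AfterAll") is
// used verbatim, matching metadata={op_name="..."} in the CHECK lines.
OpMetadata FromNameLoc(const NameLoc& loc) {
  OpMetadata m;
  m.op_name = loc.name.substr(0, loc.name.find('@'));
  return m;
}

// loc("file":line:col) maps to source_file/source_line; the column has no
// counterpart in OpMetadata and is dropped.
OpMetadata FromFileLineColLoc(const FileLineColLoc& loc) {
  OpMetadata m;
  m.source_file = loc.file;
  m.source_line = loc.line;
  return m;
}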
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/kernels/conv_grad_shape_utils.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" @@ -68,12 +71,22 @@ namespace { constexpr char kShardingAttr[] = "mhlo.sharding"; class LegalizeTF : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + public: LegalizeTF() = default; LegalizeTF(const LegalizeTF &) {} - explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo) { + explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo, + llvm::Optional tf2xla_fallback_device_type) { allow_partial_conversion_ = allow_partial_conversion; legalize_chlo_ = legalize_chlo; + use_tf2xla_fallback_ = tf2xla_fallback_device_type.hasValue(); + if (tf2xla_fallback_device_type.hasValue()) { + device_type_ = tf2xla_fallback_device_type.getValue().str(); + } } /// Performs the lowering to XLA dialect. @@ -89,15 +102,26 @@ class LegalizeTF : public PassWrapper { llvm::cl::desc( "Also legalizes intermediate chlo ops to hlo (default true)"), llvm::cl::init(true)}; + Option use_tf2xla_fallback_{ + *this, "use-tf2xla-fallback", + llvm::cl::desc( + "Also use TF2XLA fallback for legalization (default false)"), + llvm::cl::init(false)}; + Option device_type_{ + *this, "device-type", + llvm::cl::desc( + "The device type used by TF2XLA fallback. Must be specified if " + "use-tf2xla-fallback is true, otherwise not used."), + llvm::cl::init("INVALID_DEVICE_TYPE")}; }; /// Returns if the given TF data format string is the default format. static bool IsDefaultDataFormat(StringRef format) { return format == "NHWC"; } /// Returns the feature dimension for the given format and input type. -static size_t GetFeatureDimension(StringAttr format, +static size_t GetFeatureDimension(StringRef format, RankedTensorType inputType) { - return IsDefaultDataFormat(format.getValue()) ? inputType.getRank() - 1 : 1; + return IsDefaultDataFormat(format) ? inputType.getRank() - 1 : 1; } // Gets all integer values from the given attribute and push them to `values`. @@ -246,49 +270,21 @@ tensorflow::TensorShape ToTensorShape( sizes.begin(), sizes.end())); } -// Returns minimal value for the given int or float element type. -static ConstOp GetMinValueForType(Type ty, Location loc, - PatternRewriter *rewriter) { - RankedTensorType scalar_ty = RankedTensorType::get({}, ty); - - DenseElementsAttr attr; - if (auto float_ty = ty.dyn_cast_or_null()) { - APFloat neg_inf = - APFloat::getInf(float_ty.getFloatSemantics(), /*negative=*/true); - attr = DenseElementsAttr::get(scalar_ty, neg_inf); - } else { - auto int_ty = ty.cast(); - APInt min_val = APInt::getSignedMinValue(int_ty.getWidth()); - attr = DenseElementsAttr::get(scalar_ty, min_val); - } - return rewriter->create(loc, attr); -} - -// Returns maximal value for the given int or float element type. 
-static ConstOp GetMaxValueForType(Type ty, Location loc, - PatternRewriter *rewriter) { - RankedTensorType scalar_ty = RankedTensorType::get({}, ty); - - DenseElementsAttr attr; - if (auto float_ty = ty.dyn_cast_or_null()) { - APFloat pos_inf = - APFloat::getInf(float_ty.getFloatSemantics(), /*negative=*/false); - attr = DenseElementsAttr::get(scalar_ty, pos_inf); - } else { - auto int_ty = ty.cast(); - APInt max_val = APInt::getSignedMaxValue(int_ty.getWidth()); - attr = DenseElementsAttr::get(scalar_ty, max_val); - } - return rewriter->create(loc, attr); -} - -// Returns int or float scalar DenseElementsAttr attribute with the given -// element type and the value. +// Returns int, float, or complex scalar DenseElementsAttr attribute with the +// given element type and the value. static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, OpBuilder *builder) { return builder->create(loc, hlo::GetScalarOfType(ty, raw_value)); } +// Returns a limit scalar const op for the given type. +// Requires FloatType or IntegerType +static ConstOp GetScalarLimitConstOfType(Type ty, Location loc, + hlo::ScalarLimit limit, + OpBuilder *builder) { + return builder->create(loc, hlo::GetScalarLimitOfType(ty, limit)); +} + // Creates an mhlo::SliceOp where the major dimensions have full size, and // the minor dimensions have the provided offsets and sizes. static Value SliceInMinorDims(Location loc, Value v, @@ -735,12 +731,33 @@ static void CreateWhile32(Location loc, int num_iterations, // BatchNorm op utilities. //===----------------------------------------------------------------------===// -static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format, +static IntegerAttr getFeatureDimensionAttr(Builder &b, StringRef format, Value input) { return b.getI64IntegerAttr( GetFeatureDimension(format, input.getType().cast())); } +//===----------------------------------------------------------------------===// +// FFT op utilities. +//===----------------------------------------------------------------------===// +// Returns the 1D i64 elements attribute populated with the inner-most dim of +// the value. +static DenseIntElementsAttr GetInnerDimFromValue(ShapedType type, + Builder *builder) { + if (type.getRank() == 0) { + return builder->getI64TensorAttr({}); + } + return builder->getI64TensorAttr(type.getShape().back()); +} + +// Returns True if the inner-most dim is static. +bool CheckInnerDimStatic(ShapedType type, Builder *builder) { + if (!type.hasRank()) { + return false; + } + return !type.isDynamicDim(type.getShape().size() - 1); +} + //===----------------------------------------------------------------------===// // MatMul op utilities. //===----------------------------------------------------------------------===// @@ -1049,6 +1066,21 @@ static void BuildSortComparisonBody(llvm::ArrayRef element_types, builder->create(loc, compare); } +//===----------------------------------------------------------------------===// +// XlaGather op utilities. 
+//===----------------------------------------------------------------------===// + +bool HasValidGatherDims(StringAttr attr) { + ::xla::GatherDimensionNumbers dims; + return dims.ParseFromString(attr.getValue().str()); +} + +GatherDimensionNumbers GetGatherDimNumsAttr(StringAttr attr, Builder *builder) { + ::xla::GatherDimensionNumbers dims; + if (!dims.ParseFromString(attr.getValue().str())) return {}; + return ::xla::ConvertGatherDimensionNumbers(dims, builder); +} + //===----------------------------------------------------------------------===// // Op converters. //===----------------------------------------------------------------------===// @@ -1096,7 +1128,7 @@ class ConvertBiasAddOp : public OpRewritePattern { PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto feature_dim = GetFeatureDimension( - op.data_formatAttr(), op.value().getType().cast()); + op.data_format(), op.value().getType().cast()); auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(), feature_dim, rewriter); rewriter.replaceOpWithNewOp(op, op.value(), bias_broadcast); @@ -1675,6 +1707,80 @@ class ConvertEinsumOp : public OpRewritePattern { } }; +template +class ConvertFFTOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + auto input_ty = op.input().getType().template cast(); + if (!input_ty.hasRank()) { + return failure(); + } + auto input_shape = input_ty.getShape(); + DenseIntElementsAttr fft_length_attr; + if (!matchPattern(op.fft_length(), m_Constant(&fft_length_attr))) { + return failure(); + } + int64_t fft_length; + if (fft_length_attr.getNumElements() != 0) { + fft_length = fft_length_attr.getValue(0).getInt(); + } else { + return failure(); + } + + std::string fft_string = "RFFT"; + if (typeid(OpTy) == typeid(TF::IRFFTOp)) { + fft_length = fft_length / 2 + 1; + fft_string = "IRFFT"; + } + auto loc = op.getLoc(); + + // The inner-most dim cannot be dynamic. 
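The xla_gather tests earlier in this patch pass tf.XlaGather a dimension_numbers string attribute that is a serialized xla::GatherDimensionNumbers proto; HasValidGatherDims and GetGatherDimNumsAttr above round-trip it with ParseFromString before converting it to an MLIR attribute. A hedged sketch of how such an attribute value can be produced, assuming the usual xla_data.proto field names (offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim); serializing these values should reproduce the byte string used in the tests, but the helper name is mine.

#include <string>

#include "tensorflow/compiler/xla/xla_data.pb.h"  // xla::GatherDimensionNumbers

// Builds dimension numbers matching what the xla_gather FileCheck lines
// expect after conversion: offset_dims=[1], collapsed_slice_dims=[0],
// start_index_map=[0], index_vector_dim=1.
std::string MakeGatherDimensionNumbersAttr() {
  xla::GatherDimensionNumbers dims;
  dims.add_offset_dims(1);
  dims.add_collapsed_slice_dims(0);
  dims.add_start_index_map(0);
  dims.set_index_vector_dim(1);
  // The serialized form is what ends up in the op's string attribute; the
  // pass parses it back out before building the MLIR attribute.
  return dims.SerializeAsString();
}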
+ if (input_ty.isDynamicDim(input_shape.size() - 1)) { + return failure(); + } + + auto expected_shape = llvm::to_vector<4>(input_shape.drop_back()); + expected_shape.push_back(fft_length); + + // Zero pad or truncate the last axis + Value reshaped = op.input(); + SmallVector begin_indices(input_shape.size(), 0); + SmallVector strides(input_shape.size(), 1); + + // Last dim larger than fft_length, slice the input + if (input_shape.back() > fft_length) { + reshaped = rewriter.create( + op.getLoc(), + RankedTensorType::get(expected_shape, input_ty.getElementType()), + op.input(), GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(expected_shape, &rewriter), + GetI64ElementsAttr(strides, &rewriter)); + + // Last dim smaller than fft_length, zero-pad the input + } else if (input_ty.getShape().back() < fft_length) { + SmallVector no_padding(input_shape.size(), 0); + SmallVector padding(input_shape.size() - 1, 0); + padding.push_back(fft_length - input_shape.back()); + Value zero = + GetScalarConstOfType(input_ty.getElementType(), loc, 0, &rewriter); + reshaped = rewriter.create( + loc, RankedTensorType::get(expected_shape, input_ty.getElementType()), + op.input(), zero, GetI64ElementsAttr(no_padding, &rewriter), + GetI64ElementsAttr(padding, &rewriter), + GetI64ElementsAttr(no_padding, &rewriter)); + } + + rewriter.replaceOpWithNewOp(op, op.getType(), reshaped, fft_string, + rewriter.getI64TensorAttr(fft_length)); + return success(); + } +}; + +using ConvertRFFTOp = ConvertFFTOp; +using ConvertIRFFTOp = ConvertFFTOp; + // The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO // BatchNormGradOp for training and a sequence of binary ops for inference. // TODO(b/145536565): move to legalize_tf_patterns.td if it applies. @@ -1708,7 +1814,7 @@ class ConvertFusedBatchNormGradBase act = rewriter.create(loc, act, kernel_type); auto feature_dim_attr = - getFeatureDimensionAttr(rewriter, op.data_formatAttr(), act); + getFeatureDimensionAttr(rewriter, op.data_format(), act); auto feature_dim = feature_dim_attr.getValue().getSExtValue(); // Gets the result values. @@ -1723,7 +1829,7 @@ class ConvertFusedBatchNormGradBase auto training_op = rewriter.create( loc, result_type, act, scale, mean, var, grad, op.epsilon(), - feature_dim_attr.getValue()); + feature_dim); x_backprop = rewriter.create(loc, training_op.getResult(), 0); @@ -1802,7 +1908,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { LogicalResult matchAndRewrite(FusedBatchNormOpT op, PatternRewriter &rewriter) const override { auto feature_dim = - getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x()); + getFeatureDimensionAttr(rewriter, op.data_format(), op.x()); auto input_type_tensor = op.x().getType().template cast(); auto input_element_type = input_type_tensor.getElementType(); @@ -1843,7 +1949,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { auto bn_train_op = rewriter.create( op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(), - op.epsilon(), feature_dim.getValue()); + op.epsilon(), feature_dim.getInt()); // HLO op outputs a tuple of tensors. Extract those results. 
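ConvertFFTOp above normalizes the innermost axis before emitting the HLO FFT: for IRFFT the effective length is fft_length/2 + 1, an oversized axis is sliced down, and an undersized one is zero-padded. A tiny standalone C++ restatement of just that size arithmetic (helper and struct names are mine), handy for sanity-checking shapes.

#include <cstdint>
#include <iostream>

struct InnerDimPlan {
  int64_t target_size;   // size of the innermost axis fed to the FFT op
  int64_t slice_amount;  // elements dropped from the end (0 if none)
  int64_t pad_amount;    // zeros appended at the end (0 if none)
};

InnerDimPlan PlanInnerDim(int64_t input_inner_size, int64_t fft_length,
                          bool is_irfft) {
  // IRFFT consumes fft_length/2 + 1 inputs along the transformed axis, so
  // the axis is normalized to that size instead of fft_length itself.
  int64_t target = is_irfft ? fft_length / 2 + 1 : fft_length;
  InnerDimPlan plan{target, 0, 0};
  if (input_inner_size > target) plan.slice_amount = input_inner_size - target;
  if (input_inner_size < target) plan.pad_amount = target - input_inner_size;
  return plan;
}

int main() {
  // e.g. RFFT of a length-5 axis with fft_length=8 needs 3 trailing zeros.
  InnerDimPlan p = PlanInnerDim(/*input_inner_size=*/5, /*fft_length=*/8,
                                /*is_irfft=*/false);
  std::cout << p.target_size << " " << p.pad_amount << "\n";  // prints "8 3"
  return 0;
}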
auto bn_train_op_result = bn_train_op.getResult(); Value y_out = rewriter.create( @@ -1930,7 +2036,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { op.getLoc(), /*result_type=*/bn_train_input_type_tensor, bn_train_input, op.scale(), op.offset(), op.mean(), op.variance(), op.epsilon(), - feature_dim.getValue()); + feature_dim.getInt()); // Convert back to input type to stay aligned with expected output type // for TF op. @@ -2368,16 +2474,23 @@ class ConvertMaxPoolOp : public OpRewritePattern { Type element_type = op.input().getType().template cast().getElementType(); if (!element_type.isSignlessIntOrFloat()) return failure(); + tensorflow::Padding padding; + if (!GetPaddingFromString(op.padding().str(), &padding).ok()) + return failure(); + if (padding == tensorflow::Padding::EXPLICIT) { + return failure(); + } Location loc = op.getLoc(); - ConstOp init = GetMinValueForType(element_type, loc, &rewriter); + ConstOp init = GetScalarLimitConstOfType(element_type, loc, + hlo::kInfinityLowest, &rewriter); auto input_ty = op.input().getType().template dyn_cast(); if (!input_ty) return failure(); DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter); auto reduce = rewriter.create( - loc, op.getType(), op.input(), init.getResult(), - GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()), + loc, op.getType(), op.input(), init, GetI64ElementsAttr(op.ksize()), + GetI64ElementsAttr(op.strides()), /*base_dilations=*/DenseIntElementsAttr(), /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); BuildReduceBody(element_type, &reduce.body(), &rewriter); @@ -3078,7 +3191,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // axis. For instance, if there are 4 dims, we can support a // shrink_axis_mask of 0001 (1), 0011 (3), 0111 (7), or 1111 (15), but no // other. - bool shrink_axis_mask_ok = op.shrink_axis_mask().isMask(); + bool shrink_axis_mask_ok = llvm::isMask_64(op.shrink_axis_mask()); if (!shrink_axis_mask_ok) return rewriter.notifyMatchFailure( op, @@ -3087,27 +3200,27 @@ class ConvertStridedSliceOp : public OpRewritePattern { // When begin/end values are dynamic, the ellipsis mask, if set, must refer // to the last dimension. 
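The switch above from APInt::isMask() to llvm::isMask_64() on shrink_axis_mask keeps the same requirement: the mask must be a contiguous run of low bits, so only the leading axes can be shrunk. A minimal C++ equivalent of that predicate (the same bit trick, not LLVM's implementation; the function name is mine) together with the accepted 4-D values from the comment.

#include <cassert>
#include <cstdint>

// True for 0b1, 0b11, 0b111, ... and false for 0, matching llvm::isMask_64.
// Adding 1 carries through the run of low set bits, so ANDing with the
// original value leaves nothing behind exactly when the value is such a run.
bool IsLowBitMask(uint64_t v) { return v != 0 && (v & (v + 1)) == 0; }

int main() {
  assert(IsLowBitMask(0b0001) && IsLowBitMask(0b0011) &&
         IsLowBitMask(0b0111) && IsLowBitMask(0b1111));
  assert(!IsLowBitMask(0b0101) && !IsLowBitMask(0b0100));
  return 0;
}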
- int ellipsis_mask = op.ellipsis_mask().getZExtValue(); + int ellipsis_mask = op.ellipsis_mask(); if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim))) return rewriter.notifyMatchFailure( op, "requires that ellipsis_mask, if set, refer to the last dimension of " "input (when begin/end values are dynamic)"); - APInt begin_mask = op.begin_mask(); - if (!begin_mask.isNullValue()) + uint64_t begin_mask = op.begin_mask(); + if (begin_mask) return rewriter.notifyMatchFailure( op, "requires that begin_mask is either set to 0 or not set when " "begin/end values are dynamic"); - APInt end_mask = op.end_mask(); - if (!end_mask.isNullValue()) + uint64_t end_mask = op.end_mask(); + if (end_mask) return rewriter.notifyMatchFailure( op, "requires that end_mask is either set to 0 or not set when begin/end " "values are dynamic"); - APInt new_axis_mask = op.new_axis_mask(); - if (!new_axis_mask.isNullValue()) + uint64_t new_axis_mask = op.new_axis_mask(); + if (new_axis_mask) return rewriter.notifyMatchFailure( op, "requires that new_axis_mask is either set to 0 or not set when " @@ -3620,7 +3733,8 @@ class ConvertMaxOp static Value GetInitialValue(Type reduce_element_type, Location loc, PatternRewriter *rewriter) { - return GetMinValueForType(reduce_element_type, loc, rewriter); + return GetScalarLimitConstOfType(reduce_element_type, loc, + hlo::kInfinityLowest, rewriter); } }; @@ -3637,7 +3751,8 @@ class ConvertMinOp static Value GetInitialValue(Type reduce_element_type, Location loc, PatternRewriter *rewriter) { - return GetMaxValueForType(reduce_element_type, loc, rewriter); + return GetScalarLimitConstOfType(reduce_element_type, loc, + hlo::kInfinityMax, rewriter); } }; @@ -3773,7 +3888,8 @@ class ConvertArgMaxOp static Value GetInitialValue(Type reduce_element_type, Location loc, PatternRewriter &rewriter) { - return GetMinValueForType(reduce_element_type, loc, &rewriter); + return GetScalarLimitConstOfType(reduce_element_type, loc, + hlo::kInfinityLowest, &rewriter); } static StringRef GetDirection() { return "GT"; } @@ -4360,7 +4476,7 @@ class ConvertOneHotOp : public OpRewritePattern { } int64_t depth = depth_attr.getValue({}).getSExtValue(); - int64_t axis = op.axis().getSExtValue(); + int64_t axis = op.axis(); if (axis == -1) axis = indices_shape.size(); llvm::SmallVector broadcast_dims(indices_shape.size()); @@ -4636,7 +4752,7 @@ class ConvertUnpackOp : public OpRewritePattern { if (!value_type) return failure(); int64_t value_rank = value_type.getRank(); - int64_t axis = op.axis().getSExtValue(); + int64_t axis = op.axis(); if (axis < 0) axis += value_rank; // Parameters for constructing each slice. @@ -4712,7 +4828,7 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { auto output_type = RankedTensorType::get(output_shape, data_type.getElementType()); - // Broadccast the initial value for reduction. This will become the + // Broadcast the initial value for reduction. This will become the // 'operand' parameter to scatter to for the final scatter op. 
Value init = ConcreteClass::GetInitialValue(data_type.getElementType(), op.getLoc(), &rewriter); @@ -4752,7 +4868,8 @@ class ConvertUnsortedSegmentMaxOp static Value GetInitialValue(Type reduce_element_type, Location loc, PatternRewriter *rewriter) { - return GetMinValueForType(reduce_element_type, loc, rewriter); + return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kLowest, + rewriter); } }; @@ -4765,7 +4882,8 @@ class ConvertUnsortedSegmentMinOp static Value GetInitialValue(Type reduce_element_type, Location loc, PatternRewriter *rewriter) { - return GetMaxValueForType(reduce_element_type, loc, rewriter); + return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kMax, + rewriter); } }; @@ -5007,7 +5125,12 @@ class ConvertInplaceUpdateOp : public OpRewritePattern { SmallVector unpacked_indices_type( indices_type.getDimSize(0), RankedTensorType::get({}, indices_type.getElementType())); - auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(64), 0); + // Note on zero_attr integer type: DynamicUpdateSlice op start_indices are + // required to have matching types. This rewrite rule creates + // DynamicUpdateSlice ops where the first "start index" is always i32 and + // subsequent ones are constructed based on zero_attr. Thus the type + // for zero_attr needs to be i32 as well. + auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(32), 0); auto unpacked_indices = rewriter.create( op.getLoc(), unpacked_indices_type, indices, zero_attr); @@ -5071,26 +5194,25 @@ class ConvertXlaDynamicUpdateSliceOp } }; -/// Converts the Cumsum TensorFlow op to the HLO ReduceWindow op by setting -/// appropriate window dimensions, with 'add' as the reduction function. The -/// input tensor needs to have a static shape, and 'axis' must be const. The -/// TableGen pattern is not used for this rewrite because it involves regions. -class ConvertCumsumOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +// Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by +// setting appropriate window dimensions, with the given aggregation op as the +// reduction function. The input tensor needs to have a static shape, and 'axis' +// must be const. The TableGen pattern is not used for this rewrite because it +// involves regions. +template +class ConvertCumOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::CumsumOp op, + LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const override { auto input = op.x(); - auto input_type = input.getType().dyn_cast(); + auto input_type = input.getType().template dyn_cast(); if (!input_type || !input_type.hasStaticShape()) { return failure(); } - // TODO(jennik): Add support for the optional 'exclusive' and 'reverse' - // arguments. - if (op.exclusive() || op.reverse()) { - return failure(); - } + ArrayRef input_shape = input_type.getShape(); + int64_t rank = input_shape.size(); // We can only match when the axis is a constant scalar. DenseIntElementsAttr axis_attr; @@ -5098,15 +5220,6 @@ class ConvertCumsumOp : public OpRewritePattern { return failure(); } - // Convert if we need to enlarge the element type's bitwidth to avoid - // precision loss. 
- Type input_element_type = input_type.getElementType(); - Type sum_element_type = GetSumAccumulationType(input_element_type); - input = rewriter.create(op.getLoc(), input, sum_element_type); - - ArrayRef input_shape = input_type.getShape(); - int64_t rank = input_shape.size(); - // Get the dimension to apply the reduction on, and offset properly if it is // negative. int64_t axis = (*axis_attr.begin()).getSExtValue(); @@ -5114,6 +5227,25 @@ class ConvertCumsumOp : public OpRewritePattern { axis += rank; } + // If we're supposed to sum things up in the reverse direction, we reverse + // the input and then later reverse the output. + if (op.reverse()) { + llvm::SmallVector dims_to_reverse({axis}); + input = rewriter.create( + op.getLoc(), op.getType(), input, + GetI64ElementsAttr(dims_to_reverse, &rewriter)); + } + + // Convert if we need to enlarge the element type's bitwidth to avoid + // precision loss. + Type input_element_type = input_type.getElementType(); + + // TODO(hinsu): Handle complex element types. + if (!input_element_type.isIntOrFloat()) return failure(); + + Type sum_element_type = GetSumAccumulationType(input_element_type); + input = rewriter.create(op.getLoc(), input, sum_element_type); + SmallVector window_dims(rank, 1); SmallVector window_strides(rank, 1); window_dims[axis] = input_shape[axis]; @@ -5124,8 +5256,9 @@ class ConvertCumsumOp : public OpRewritePattern { RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)), paddings); - Value init = - GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter); + int64_t init_value = (std::is_same::value) ? 0 : 1; + Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value, + &rewriter); auto reduce = rewriter.create( op.getLoc(), input_type, input, init, @@ -5133,18 +5266,45 @@ class ConvertCumsumOp : public OpRewritePattern { GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)), /*base_dilations=*/DenseIntElementsAttr(), /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); - BuildReduceBody(sum_element_type, &reduce.body(), &rewriter); + BuildReduceBody(sum_element_type, &reduce.body(), &rewriter); Value result = reduce.getResult(); + if (op.exclusive()) { + // In "exclusive" operation, the output will start with the "init" (0) + // values. There is no way to express that as a ReduceWindowOp, so run the + // normal operation, and then use a PadOp to add the 0 "column" on the + // left and cut away the last column on the right. + llvm::SmallVector low_padding(rank, 0); + llvm::SmallVector high_padding(rank, 0); + llvm::SmallVector interior_padding(rank, 0); + low_padding[axis] = 1; + high_padding[axis] = -1; + result = rewriter.create( + op.getLoc(), op.getType(), result, init, + GetI64ElementsAttr(low_padding, &rewriter), + GetI64ElementsAttr(high_padding, &rewriter), + GetI64ElementsAttr(interior_padding, &rewriter)); + } + // Convert back if we enlarged the element type's bitwidth. result = rewriter.create(op.getLoc(), result, input_element_type); + if (op.reverse()) { + llvm::SmallVector dims_to_reverse({axis}); + result = rewriter.create( + op.getLoc(), op.getType(), result, + GetI64ElementsAttr(dims_to_reverse, &rewriter)); + } + rewriter.replaceOp(op, result); return success(); } }; +using ConvertCumsumOp = ConvertCumOp; +using ConvertCumprodOp = ConvertCumOp; + // Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard // dialect lowerings. 
This involves extracting the shape type, extracting and // converting each dimension to a known integer type, and repacking into a final @@ -5173,6 +5333,101 @@ class ConvertShapeOp : public OpRewritePattern { } }; +class ConvertDynamicReshapeOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::ReshapeOp op, + PatternRewriter &rewriter) const override { + auto tensor = op.tensor(); + auto shape = op.shape(); + + auto tensor_ty = tensor.getType().cast(); + auto shape_ty = shape.getType().cast(); + auto result_ty = op.getType().cast(); + + if (!result_ty.hasRank() || !tensor_ty.hasRank() || !shape_ty.hasRank()) { + return failure(); + } + + // Handle with the static case. + if (result_ty.hasStaticShape()) { + return failure(); + } + + rewriter.replaceOpWithNewOp(op, result_ty, tensor, + shape); + return success(); + } +}; + +class ConvertDynamicExpandDimsOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::ExpandDimsOp op, + PatternRewriter &rewriter) const override { + auto input = op.input(); + auto input_ty = input.getType().cast(); + auto result_ty = op.getType().cast(); + if (!result_ty.hasRank() || !input_ty.hasRank() || + result_ty.hasStaticShape()) { + return failure(); + } + + DenseIntElementsAttr expand_dims_attr; + if (!matchPattern(op.dim(), m_Constant(&expand_dims_attr))) { + return failure(); + } + + auto shape = rewriter.create( + op.getLoc(), + RankedTensorType::get({input_ty.getRank()}, rewriter.getIndexType()), + input); + auto expand_dims = llvm::to_vector<6>(expand_dims_attr.getIntValues()); + + llvm::SmallVector dims; + dims.resize(result_ty.getRank()); + + auto inserted_dim = expand_dims_attr.getValue({}) + .cast() + .getValue() + .getSExtValue(); + + // Handle the negative value use case. + if (inserted_dim < 0) { + inserted_dim += result_ty.getRank(); + // This means the value is completely incorrect, just return. + if (inserted_dim < 0) { + return failure(); + } + } + + dims[inserted_dim] = rewriter.create(op.getLoc(), 1); + + for (int i = 0; i < dims.size() - 1; i++) { + // Add the extracted dim. + auto index = rewriter.create(op.getLoc(), i); + auto dim = rewriter.create( + op.getLoc(), rewriter.getIndexType(), shape, index); + + dims[i >= inserted_dim ? i + 1 : i] = dim; + } + + auto from_extents = rewriter.create( + op.getLoc(), shape::ShapeType::get(op.getContext()), dims); + + auto to_extent_tensor = rewriter.create( + op.getLoc(), + RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType()), + from_extents); + + rewriter.replaceOpWithNewOp(op, result_ty, input, + to_extent_tensor); + return success(); + } +}; + // Converts a TF QR op to HLO. class ConvertQrOp : public OpRewritePattern { public: @@ -5672,7 +5927,7 @@ class ConvertQrOp : public OpRewritePattern { void EmitLegalizationErrors(Operation *op, const DenseSet &nonlegalized_ops) { // Track the legalization failures by mapping op name to information about - // that failure: the number of unlegalized occurances of the op, and one + // that failure: the number of unlegalized occurrences of the op, and one // example operation that failed. std::map> op_name_to_error_info; DenseSet error_ops; @@ -5714,9 +5969,14 @@ void EmitLegalizationErrors(Operation *op, // Performs the lowering to XLA dialect. 
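The ConvertCumOp template above (shared by Cumsum and Cumprod) and the cumsum_exclusive / cumsum_reverse / cumsum_exclusive_reverse FileCheck tests earlier in this patch all encode the same scalar recurrence: reverse flips the axis before and after the scan, and exclusive shifts the result by one starting from the identity, which is what the pad with edge_padding_low=1, edge_padding_high=-1 implements. A short plain C++ reference of the intended semantics (not of the ReduceWindow lowering itself; the function name is mine), useful for eyeballing expected values.

#include <algorithm>
#include <cstddef>
#include <vector>

// Reference cumulative sum/product along a 1-D axis. `exclusive` starts the
// scan from the identity (0 for sum, 1 for prod) and drops the final total;
// `reverse` flips the axis before and after the scan.
std::vector<float> Cum(std::vector<float> x, bool exclusive, bool reverse,
                       bool is_prod) {
  if (reverse) std::reverse(x.begin(), x.end());
  float acc = is_prod ? 1.0f : 0.0f;
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    if (exclusive) {
      out[i] = acc;  // value before including x[i]
      acc = is_prod ? acc * x[i] : acc + x[i];
    } else {
      acc = is_prod ? acc * x[i] : acc + x[i];
      out[i] = acc;  // value after including x[i]
    }
  }
  if (reverse) std::reverse(out.begin(), out.end());
  return out;
}

// Example: Cum({1,2,3,4}, /*exclusive=*/true, /*reverse=*/false, false)
// returns {0,1,3,6}, the same thing the lowering's pad-and-shift produces
// from the inclusive result {1,3,6,10}.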
void LegalizeTF::runOnFunction() { - if (failed( - legalizeTF(getFunction(), allow_partial_conversion_, legalize_chlo_))) + llvm::Optional tf2xla_fallback_device_type = llvm::None; + if (use_tf2xla_fallback_) { + tf2xla_fallback_device_type = device_type_; + } + if (failed(legalizeTF(getFunction(), allow_partial_conversion_, + legalize_chlo_, tf2xla_fallback_device_type))) { signalPassFailure(); + } } static PassRegistration pass( @@ -5726,19 +5986,38 @@ static PassRegistration pass( #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" -LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion, - bool legalize_chlo) { +LogicalResult legalizeTF( + Operation *op, bool allow_partial_conversion, bool legalize_chlo, + llvm::Optional tf2xla_fallback_device_type) { MLIRContext *context = op->getContext(); - - // Add lowering patterns to the list. OwningRewritePatternList patterns; + // Note that the `OperationConverter` orders patterns lexicographically by: + // 1) Ascending legalization depth (i.e., minimum number of patterns necessary + // to arrive at conversion target). + // 2) Descending pattern benefit. + // 3) Order of patterns in `OwningRewritePatternList`. + + // Add TF->HLO legalization patterns. PopulateLegalizeTfPatterns(context, &patterns); + // Add TF->TF lowering patterns. + TF::PopulateLoweringTFPatterns(context, &patterns); + + // Add TF->HLO legalization patterns via TF2XLA fallback. + if (tf2xla_fallback_device_type.hasValue()) { + PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(), + patterns); + } + // Populate with CHLO->HLO lowerings to account for TF ops legalized to // CHLO first. if (legalize_chlo) { chlo::PopulateLegalizeChloToHloPatterns(context, &patterns); } + // ConstantLike op is convenient to create splat constants, but is + // canonicalized to plain HLO constant if statically shaped. Add the + // canonicalization pattern to pattern list to enable multi-hop lowering. + chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context); ConversionTarget target(*context); if (legalize_chlo) { @@ -5773,28 +6052,25 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion, void PopulateLegalizeTfPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { populateWithGenerated(context, patterns); - - // Add patterns that lower some of the high level TensorFlow ops to lower - // level TensorFlow ops. So, we don't have to target all the TensorFlow ops - // here for lowering to HLO. 
- TF::PopulateLoweringTFPatterns(context, patterns); patterns->insert< ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op, ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp, ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp, ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp, ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp, - ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp, - ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, - ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op, - ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, - ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp, - ConvertAvgPool2DOp, ConvertAvgPool3DOp, ConvertAvgPool2DGradOp, - ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp, - ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp, - ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp, - ConvertDynamicRangeOp, ConvertMatrixDiagPartV3Op, ConvertRangeOp, - ConvertSelectV2Op, ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp, + ConvertCumprodOp, ConvertCumsumOp, ConvertDiagPartOp, + ConvertDynamicExpandDimsOp, ConvertDynamicReshapeOp, ConvertEinsumOp, + ConvertRFFTOp, ConvertIRFFTOp, ConvertFusedBatchNormGradOp, + ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op, + ConvertFusedBatchNormV2Op, ConvertFusedBatchNormV3Op, + ConvertInfeedDequeueTupleOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp, + ConvertMaxOp, ConvertMinOp, ConvertAvgPool2DOp, ConvertAvgPool3DOp, + ConvertAvgPool2DGradOp, ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, + ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, + ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, + ConvertProdOp, ConvertQrOp, ConvertDynamicRangeOp, + ConvertMatrixDiagPartV3Op, ConvertRangeOp, ConvertSelectV2Op, + ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp, ConvertSoftmaxOp, ConvertSoftmaxOp, ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp, @@ -5806,8 +6082,10 @@ void PopulateLegalizeTfPatterns(MLIRContext *context, } std::unique_ptr> createLegalizeTFPass( - bool allow_partial_conversion, bool legalize_chlo) { - return std::make_unique(allow_partial_conversion, legalize_chlo); + bool allow_partial_conversion, bool legalize_chlo, + llvm::Optional tf2xla_fallback_device_type) { + return std::make_unique(allow_partial_conversion, legalize_chlo, + tf2xla_fallback_device_type); } } // end namespace mhlo diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc index 588e31ab669..6320ad2032b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc @@ -22,15 +22,20 @@ limitations under the License. 
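With the extra parameter on createLegalizeTFPass, enabling the TF2XLA fallback programmatically means supplying a device type where callers previously passed nothing. A hedged usage sketch: the llvm::Optional element type is assumed to be StringRef (the template argument is not visible above), "XLA_CPU_JIT" is only an illustrative device string, and the wrapper function name is mine.

#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringRef.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"  // createLegalizeTFPass

// Builds the TF->HLO legalization pass with the TF2XLA fallback enabled.
// Passing llvm::None as the last argument keeps the previous behavior
// (no fallback); the device string below is a placeholder.
auto MakeLegalizeTFPassWithFallback() {
  llvm::Optional<llvm::StringRef> device_type = llvm::StringRef("XLA_CPU_JIT");
  return mlir::mhlo::createLegalizeTFPass(
      /*allow_partial_conversion=*/false,
      /*legalize_chlo=*/true,
      /*tf2xla_fallback_device_type=*/device_type);
}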
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/None.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" @@ -49,45 +54,104 @@ const char kXlaHostTransferOriginalTypeAttr[] = "_xla_host_transfer_original_type"; // A pass that legalizes TF/XLA communication ops, propagate their respective -// tokens (for ordering), and rewrite their respective functions when necessary. +// tokens (for ordering), and rewrite their respective functions and control +// flow ops when necessary. // Note, this currently does not handle nested modules/functions or region based -// ops (e.g. control flow). +// ops other than certain control flow ops (`mhlo.if`, `mhlo.while`). class LegalizeTFCommunication : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: void runOnOperation() override; }; -// Checks if a function has any communication ops. -bool HasCommunicationOps(FuncOp func) { - auto result = func.walk([](Operation* op) { - if (isa(op)) +// Checks if an op is a TF/XLA communication op. +bool IsCommunicationOp(Operation* op) { + return isa(op); +} + +// Checks if an op is a supported HLO control flow op. +bool IsControlFlowOp(Operation* op) { return isa(op); } + +// Collects control flow op ancestors of a given op, up until FuncOp. If any +// ancestor is not a control flow op or a FuncOp, or of a single block region, +// an error will be returned. +LogicalResult GetControlFlowAncestors( + Operation* op, llvm::SmallPtrSetImpl& control_flow_ops, + llvm::SmallPtrSetImpl& control_flow_blocks) { + Block* block = op->getBlock(); + Operation* parent = block->getParentOp(); + while (block && parent && !isa(parent)) { + if (!IsControlFlowOp(parent)) + return op->emitOpError() + << "expects ancestor(s) to be of ['" << IfOp::getOperationName() + << "', '" << FuncOp::getOperationName() << "']"; + + if (!llvm::hasSingleElement(block->getParent()->getBlocks())) + return op->emitOpError() << "expects single block region ancestor(s)"; + + control_flow_ops.insert(parent); + control_flow_blocks.insert(block); + + parent = block->getParentOp(); + block = parent->getBlock(); + } + return success(); +} + +// Finds communication ops in a function. `control_flow_ops` and +// `control_flow_blocks` will be populated with control flow op ancestors for +// every communication op. 
+LogicalResult FindCommunicationOps( + FuncOp func, llvm::SmallPtrSetImpl& control_flow_ops, + llvm::SmallPtrSetImpl& control_flow_blocks, + bool& has_communication_ops) { + auto result = func.walk([&](Operation* op) { + if (!IsCommunicationOp(op)) return WalkResult::advance(); + has_communication_ops = true; + if (failed( + GetControlFlowAncestors(op, control_flow_ops, control_flow_blocks))) return WalkResult::interrupt(); return WalkResult::advance(); }); - return result.wasInterrupted(); + return failure(result.wasInterrupted()); } -// Helper struct holding a function and optional cloned version. If `clone` is -// set, function calls to `original` will be replaced with `clone`. -struct FuncAndClone { +// Helper struct holding a function to be rewritten, it's control flow ops that +// lead to a communication op or function call with a communication op +// (transitively), and an optional clone of itself. If `clone` is set, function +// calls to `original` will be replaced with `clone`. +struct FuncToRewrite { FuncOp original; + llvm::SmallPtrSet control_flow_ops; + llvm::SmallPtrSet control_flow_blocks; FuncOp clone; }; // Finds all functions that need to be rewritten with communication ops and // and associated tokens. -llvm::SmallDenseMap GetFunctionsToRewrite( - ModuleOp module) { +LogicalResult GetFunctionsToRewrite( + ModuleOp module, + llvm::SmallDenseMap& funcs_to_rewrite) { // Find functions containing communication ops. - llvm::SmallDenseMap funcs; SmallVector funcs_to_visit; for (FuncOp func : module.getOps()) { - if (HasCommunicationOps(func)) { - funcs.insert({func.getName(), {func, /*clone=*/nullptr}}); - funcs_to_visit.push_back(func); - } + FuncToRewrite func_to_rewrite{/*original=*/func, /*control_flow_ops=*/{}, + /*control_flow_blocks=*/{}, + /*clone=*/nullptr}; + bool has_communication_ops = false; + if (failed(FindCommunicationOps(func, func_to_rewrite.control_flow_ops, + func_to_rewrite.control_flow_blocks, + has_communication_ops))) + return failure(); + + if (!has_communication_ops) continue; + funcs_to_rewrite.insert({func.getName(), func_to_rewrite}); + funcs_to_visit.push_back(func); } // Find functions that call functions with communication ops, transitively. @@ -100,13 +164,30 @@ llvm::SmallDenseMap GetFunctionsToRewrite( // Only `mlir::CallOp` is supported as this requires knowing how to // rewrite arguments and results to a function. 
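GetFunctionsToRewrite above is a worklist fixpoint over the call graph: any function that transitively calls a function containing a communication op must also be rewritten to thread the token. A compact standalone C++ model of that closure, with plain strings standing in for FuncOps and the symbol table, and the merging of control-flow ancestor sets left out; the function name is mine.

#include <map>
#include <set>
#include <string>
#include <vector>

// callers_of[f] lists the functions that call f. Starting from the functions
// that directly contain communication ops, keep adding callers until no new
// function appears; the result is every function that needs a token.
std::set<std::string> FunctionsToRewrite(
    const std::map<std::string, std::vector<std::string>>& callers_of,
    const std::set<std::string>& has_communication_op) {
  std::set<std::string> to_rewrite(has_communication_op);
  std::vector<std::string> worklist(has_communication_op.begin(),
                                    has_communication_op.end());
  while (!worklist.empty()) {
    std::string func = worklist.back();
    worklist.pop_back();
    auto it = callers_of.find(func);
    if (it == callers_of.end()) continue;
    for (const std::string& caller : it->second)
      if (to_rewrite.insert(caller).second) worklist.push_back(caller);
  }
  return to_rewrite;
}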
if (!isa(use.getUser())) continue; - auto caller_func = use.getUser()->getParentOfType(); - if (!caller_func) continue; - if (funcs - .insert( - {caller_func.getName(), {caller_func, /*clone=*/nullptr}}) - .second) - new_funcs_to_visit.push_back(caller_func); + auto caller_parent_func = use.getUser()->getParentOfType(); + if (!caller_parent_func) continue; + + FuncToRewrite func_to_rewrite{/*original=*/caller_parent_func, + /*control_flow_ops=*/{}, + /*control_flow_blocks=*/{}, + /*clone=*/nullptr}; + if (failed(GetControlFlowAncestors( + use.getUser(), func_to_rewrite.control_flow_ops, + func_to_rewrite.control_flow_blocks))) + return failure(); + + auto it = funcs_to_rewrite.insert( + {caller_parent_func.getName(), func_to_rewrite}); + if (it.second) { + new_funcs_to_visit.push_back(caller_parent_func); + } else { + it.first->getSecond().control_flow_ops.insert( + func_to_rewrite.control_flow_ops.begin(), + func_to_rewrite.control_flow_ops.end()); + it.first->getSecond().control_flow_blocks.insert( + func_to_rewrite.control_flow_blocks.begin(), + func_to_rewrite.control_flow_blocks.end()); + } } } @@ -116,8 +197,9 @@ llvm::SmallDenseMap GetFunctionsToRewrite( // Clone public functions that need to be rewritten. Function calls to this // function will be replaced with the cloned function. SymbolTable symbol_table(module); - for (auto& func : funcs) { - if (func.getSecond().original.isPublic()) { + for (auto& func : funcs_to_rewrite) { + if (func.getSecond().original.isPublic() && + !func.getSecond().original.symbolKnownUseEmpty(module)) { auto clone = func.getSecond().original.clone(); clone.setVisibility(SymbolTable::Visibility::Private); symbol_table.insert(clone); @@ -125,7 +207,7 @@ llvm::SmallDenseMap GetFunctionsToRewrite( } } - return funcs; + return success(); } // Assigns op sharding to an op for a given device core. @@ -137,11 +219,17 @@ void SetOpSharding(Operation* op, int64_t tpu_core) { } // Assigns frontend attributes holding information about data type and -// TensorFlow rendezvous channel name. -void SetFrontendAttributes(Operation* op, StringRef key, Type type) { +// TensorFlow rendezvous channel name. The TensorFlow rendezvous channel name is +// handled differently as individual names are used per data send and receive. +void SetFrontendAttributes(Operation* op, int32_t index, StringRef key, + Type type, bool device_to_host) { MLIRContext* context = op->getContext(); - auto rendezvous_name = StringAttr::get(key, context); + std::string formatted_key = + device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str() + : llvm::formatv("{0}_htod_{1}", key, index).str(); + + auto rendezvous_name = StringAttr::get(formatted_key, context); auto rendezvous_name_attr = NamedAttribute( Identifier::get(kXlaHostTransferRendezvousNameAttr, context), rendezvous_name); @@ -161,24 +249,10 @@ void SetFrontendAttributes(Operation* op, StringRef key, Type type) { op->setAttr(kFrontendAttributesAttr, frontend_attributes); } -// Assigns frontend attributes holding information about data type and -// TensorFlow rendezvous channel name specific to `tf._XlaHostComputeMlir`. -// TensorFlow rendezvous channel name is handled differently as individual names -// are used per data send and receive. -void SetFrontendAttributes(Operation* op, int32_t index, StringRef key, - Type type, bool device_to_host) { - std::string formatted_key = - device_to_host ? 
llvm::formatv("{0}_dtoh_{1}", key, index).str() - : llvm::formatv("{0}_htod_{1}", key, index).str(); - - return SetFrontendAttributes(op, formatted_key, type); -} - -// Creates a `mhlo.send` op for sending value `operand`. If `index` is set, -// `key` will be rewritten with a suffix and index. If `tpu_core` is set, op -// sharding for the respective device will be set. +// Creates a `mhlo.send` op for sending value `operand`. If `tpu_core` is set, +// op sharding for the respective device will be set. Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc, - Value operand, StringRef key, const Optional& index, + Value operand, StringRef key, size_t index, const Optional& tpu_core, Value token) { // type 2 == DEVICE_TO_HOST auto channel_handle = ChannelHandle::get( @@ -188,23 +262,18 @@ Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc, loc, token.getType(), operand, token, channel_handle, /*is_host_transfer=*/builder.getBoolAttr(true)); - if (index) { - SetFrontendAttributes(send, *index, key, operand.getType(), - /*device_to_host=*/true); - } else { - SetFrontendAttributes(send, key, operand.getType()); - } + SetFrontendAttributes(send, index, key, operand.getType(), + /*device_to_host=*/true); if (tpu_core) SetOpSharding(send, *tpu_core); return send.getResult(); } -// Creates a `mhlo.recv` op for receiving a value. If `index` is set, `key` will -// be rewritten with a suffix and index. If `tpu_core` is set, op sharding for -// the respective device will be set. +// Creates a `mhlo.recv` op for receiving a value. If `tpu_core` is set, op +// sharding for the respective device will be set. Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc, - Value result, StringRef key, const Optional& index, + Value result, StringRef key, size_t index, const Optional& tpu_core, Value token) { // type 3 == HOST_TO_DEVICE auto channel_handle = ChannelHandle::get( @@ -216,12 +285,10 @@ Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc, auto recv = builder.create(loc, recv_result_type, token, channel_handle, /*is_host_transfer=*/builder.getBoolAttr(true)); - if (index) { - SetFrontendAttributes(recv, *index, key, result_type, - /*device_to_host=*/false); - } else { - SetFrontendAttributes(recv, key, result.getType()); - } + + SetFrontendAttributes(recv, index, key, result_type, + /*device_to_host=*/false); + if (tpu_core) SetOpSharding(recv, *tpu_core); auto get_tuple_element = @@ -291,7 +358,7 @@ Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id, builder.setInsertionPoint(send_to_host); token = CreateSendOp(builder, channel_id, send_to_host.getLoc(), send_to_host.input(), send_to_host.key(), - /*index=*/llvm::None, /*tpu_core=*/llvm::None, token); + /*index=*/0, /*tpu_core=*/llvm::None, token); send_to_host.erase(); return token; @@ -303,7 +370,7 @@ Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id, builder.setInsertionPoint(recv_from_host); token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(), recv_from_host.output(), recv_from_host.key(), - /*index=*/llvm::None, /*tpu_core=*/llvm::None, token); + /*index=*/0, /*tpu_core=*/llvm::None, token); recv_from_host.erase(); return token; @@ -329,94 +396,489 @@ Value RewriteCallOp(OpBuilder& builder, CallOp call, return new_call.getResults().back(); } -// Updates function terminator and type if a token is to be emitted by the -// function. 
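Each send/recv now gets its own rendezvous name built from the op key, the per-op transfer index, and the transfer direction, as formatted by SetFrontendAttributes above. A one-function C++ sketch of the naming scheme (plain string concatenation instead of llvm::formatv; the function name is mine).

#include <cstdint>
#include <string>

// "send_key" with index 0, device-to-host  -> "send_key_dtoh_0"
// "recv_key" with index 2, host-to-device  -> "recv_key_htod_2"
std::string RendezvousName(const std::string& key, int32_t index,
                           bool device_to_host) {
  return key + (device_to_host ? "_dtoh_" : "_htod_") + std::to_string(index);
}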
-void RewriteFunctionTerminatorAndUpdateType(OpBuilder& builder, FuncOp func, - Block& func_body, Value token) { - // If the function signature is changed, update to emit a token and update - // the function type. - Operation* terminator = func_body.getTerminator(); - auto new_results = llvm::to_vector<4>(terminator->getOperands()); - new_results.push_back(token); - builder.setInsertionPoint(terminator); - auto new_return = - builder.create(terminator->getLoc(), new_results); - terminator->erase(); +// Helper struct holding state of which op to visit to next. If `op` is in a +// control flow op region, `region_idx` will be set with the respective region +// index. `token` will be current token from the last communication op/control +// flow op transitive communication ops. +struct OpVisitorState { + Optional region_idx; + Value token; + Operation* op; +}; +// Creates a tuple from a sequence of values. +Value CreateTuple(OpBuilder& builder, Location loc, ArrayRef operands) { + return builder.create(loc, operands).getResult(); +} + +// Replaces a value `value` with a new value but the token attached. If `value` +// is not a tuple, a new tuple is formed with `token`. If `value` is a tuple, +// `value` is extended instead. New tuple values created are cached. +Value GetValueWithToken(OpBuilder& builder, Value value, Value token, + llvm::SmallDenseMap& rewritten_values) { + // If value with token already exists, reuse it. + auto it = rewritten_values.find(value); + if (it != rewritten_values.end()) return it->getSecond(); + + auto create_tuple = [&](ArrayRef operands) { + auto new_result = CreateTuple(builder, value.getLoc(), operands); + rewritten_values.insert({value, new_result}); + return new_result; + }; + + auto tuple_type = value.getType().dyn_cast(); + // `value` is not a tuple, create a new tuple. + if (!tuple_type) return create_tuple({value, token}); + + // Extend tuple if `value` is a tuple. + // If `value` is an op result and the owner is a `mhlo.tuple`, simply unpack + // the tuple. + if (auto tuple_op = value.getDefiningOp()) { + auto tuple_operands = llvm::to_vector<4>(tuple_op.getOperands()); + tuple_operands.push_back(token); + return create_tuple(tuple_operands); + } + + // `value` is not created via a `mhlo.tuple` directly, unpack individual + // elements directly with `mhlo.get_tuple_element`. + SmallVector tuple_operands; + for (auto idx : llvm::seq(0, tuple_type.getTypes().size())) + tuple_operands.push_back( + builder.create(value.getLoc(), value, idx) + .getResult()); + + tuple_operands.push_back(token); + return create_tuple(tuple_operands); +} + +// Extends a type to include a `mhlo.token` type. If `type` is not a tuple type, +// a new tuple type with `type` and `mhlo.token` type is created instead. +TupleType GetTypeWithToken(OpBuilder& builder, Type type) { + auto token_type = TokenType::get(builder.getContext()); + if (auto tuple_type = type.dyn_cast()) { + auto result_types = llvm::to_vector<4>(tuple_type.getTypes()); + result_types.push_back(token_type); + return builder.getTupleType(result_types); + } + + return builder.getTupleType({type, token_type}); +} + +// Creates a slice of a tuple `value` with `mhlo.get_tuple_element` from index 0 +// to `end`, exclusive. 
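GetValueWithToken and GetTypeWithToken above, together with CreateSubTuple just below, are tuple surgery: append an mhlo.token element to the value and type flowing through a control-flow region, then re-slice the original elements back out for pre-existing users. A schematic C++ model of just the type-level bookkeeping, with element types represented as strings and no MLIR involved; the function names are mine.

#include <cstddef>
#include <string>
#include <vector>

// A result type is modeled as the list of its element types; a non-tuple
// value is a single-element list. Appending the token mirrors
// GetTypeWithToken, and taking the first `end` elements mirrors
// CreateSubTuple, which rebuilds the original value for existing users.
std::vector<std::string> WithToken(std::vector<std::string> types) {
  types.push_back("!mhlo.token");
  return types;
}

std::vector<std::string> SubTuple(const std::vector<std::string>& types,
                                  std::size_t end) {
  return std::vector<std::string>(types.begin(), types.begin() + end);
}

// e.g. WithToken({"tensor<f32>", "tensor<i32>"}) yields
// {"tensor<f32>", "tensor<i32>", "!mhlo.token"}, and SubTuple(that, 2)
// recovers the original pair for the ops that still expect it.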
+Value CreateSubTuple(OpBuilder& builder, Value value, size_t end) { + SmallVector tuple_operands; + for (auto idx : llvm::seq(0, end)) + tuple_operands.push_back( + builder.create(value.getLoc(), value, idx) + .getResult()); + + return CreateTuple(builder, value.getLoc(), tuple_operands); +} + +// Replaces uses of `value` with `replacement`. If `value` is not a tuple type, +// an explicit `mhlo.get_tuple_element` is created to unpack the tuple and +// return the first element. Otherwise, `mhlo.get_tuple_element` users are +// simply updated with `replacement`, and all other users are updated with a +// slice of `replacement`. +void ReplaceWithTupleResult(OpBuilder& builder, Value value, + Value replacement) { + auto tuple_type = value.getType().dyn_cast(); + if (!tuple_type) { + if (!value.use_empty()) { + auto new_element = builder.create(replacement.getLoc(), + replacement, 0); + value.replaceAllUsesWith(new_element.getResult()); + } + return; + } + + Value sub_tuple; + for (auto& use : llvm::make_early_inc_range(value.getUses())) { + if (isa(use.getOwner())) { + use.set(replacement); + continue; + } + + if (!sub_tuple) + sub_tuple = CreateSubTuple(builder, replacement, tuple_type.size()); + + use.set(sub_tuple); + } +} + +// Replaces control flow op block single block argument with new block argument +// of type `new_type` (tuple type). The last element of the new block argument +// (token) is returned. +Value UpdateControlFlowBlockArgWithToken(OpBuilder& builder, Block& block, + Type token_type) { + assert(block.getNumArguments() == 1); + builder.setInsertionPointToStart(&block); + auto new_arg = block.addArgument(token_type); + ReplaceWithTupleResult(builder, block.getArgument(0), new_arg); + block.eraseArgument(0); + return builder + .create(new_arg.getLoc(), new_arg, + token_type.cast().size() - 1) + .getResult(); +} + +// Updates control flow op terminator with an extra element `token`. If the +// original return value is not a tuple, a new tuple is formed. Otherwise the +// tuple is extended. +void RewriteControlFlowTerminator(OpBuilder& builder, Operation* terminator, + Value token) { + assert(terminator->getNumOperands() == 1); + assert(terminator->getBlock()->getNumArguments() == 1); + // `mhlo.while` cond terminator does not need to be rewritten as it always + // returns a tensor predicate value. + if (auto while_parent = dyn_cast_or_null(terminator->getParentOp())) + if (terminator->getParentRegion() == &while_parent.cond()) return; + + builder.setInsertionPoint(terminator); + llvm::SmallDenseMap rewritten_operands; + Value new_result = GetValueWithToken(builder, terminator->getOperand(0), + token, rewritten_operands); + terminator->setOperand(0, new_result); +} + +// Rewrites a `mhlo.if` op to receive and forward a `mhlo.token`. Operands to +// the op for all of its regions are extended to have an extra operand `token`. +void RewriteRegionIfOp(OpBuilder& builder, IfOp region_if, + SmallVectorImpl& ops_to_visit, + Value token) { + llvm::SmallDenseMap rewritten_operands; + + // Rewrite all region operands to have an extra operand `token`. + Value new_true_operand = GetValueWithToken(builder, region_if.true_arg(), + token, rewritten_operands); + Value new_false_operand = GetValueWithToken(builder, region_if.false_arg(), + token, rewritten_operands); + + auto new_result_type = GetTypeWithToken(builder, region_if.getType()); + + // Create new `mhlo.if` op with extra token operands and result. 
+ auto new_if = builder.create(region_if.getLoc(), new_result_type, + region_if.pred(), new_true_operand, + new_false_operand); + + // Move all regions from the old `mhlo.if` op to its replacement. + new_if.true_branch().takeBody(region_if.true_branch()); + new_if.false_branch().takeBody(region_if.false_branch()); + + // Forward result from old `mhlo.if` with replacement, and unpack result when + // necessary. + ReplaceWithTupleResult(builder, region_if.getResult(), new_if.getResult()); + + auto new_token = builder.create( + new_if.getLoc(), new_if.getResult(), + new_if.getResult().getType().cast().size() - 1); + + region_if.erase(); + + // Remove leftover operands to old `mhlo.if` if they have no uses. + for (auto& rewritten_operand : rewritten_operands) + if (auto tuple_op = rewritten_operand.getFirst().getDefiningOp()) + if (tuple_op.use_empty()) tuple_op.erase(); + + // Next op to visit. The replacement is visited but at its first region. The + // token result of the new region if is propagated. + ops_to_visit.push_back({/*region_idx=*/0, new_token, new_if}); +} + +// Rewrites a `mhlo.if`/`mhlo.while` region to receive and forward a +// `mhlo.token`. The block argument is updated to have an extra `mhlo.token` +// element. If the region block is to be rewritten, the next op to visit is set +// to the first op in the block. Otherwise the terminator is updated to forward +// `token`. +void RewriteControlFlowOpRegion( + OpBuilder& builder, Operation* region_op, unsigned region_idx, + Type block_arg_type, SmallVectorImpl& ops_to_visit, + const llvm::SmallPtrSetImpl& control_flow_blocks, Value token) { + ops_to_visit.push_back({region_idx + 1, token, region_op}); + + Region& region = region_op->getRegion(region_idx); + assert(llvm::hasSingleElement(region)); + + auto block_token = UpdateControlFlowBlockArgWithToken(builder, region.front(), + block_arg_type); + + if (control_flow_blocks.contains(®ion.front())) { + ops_to_visit.push_back({/*region_idx=*/llvm::None, block_token, + block_token.getDefiningOp()->getNextNode()}); + return; + } + + RewriteControlFlowTerminator(builder, region.front().getTerminator(), + block_token); +} + +// Rewrites an `mhlo.if` op or its region. If `region_idx` is not set, the op +// operands and results are rewritten. If `region_idx` is set, region +// `region_idx` is rewritten to take in and return an additional token. Returns +// true if the op or its region was rewritten. +bool ProcessRegionIfOp(OpBuilder& builder, IfOp region_if, + Optional region_idx, + SmallVectorImpl& ops_to_visit, + const llvm::SmallPtrSetImpl& control_flow_blocks, + Value token) { + builder.setInsertionPoint(region_if); + + if (!region_idx) { + RewriteRegionIfOp(builder, region_if, ops_to_visit, token); + return true; + } + + if (*region_idx < region_if.getNumRegions()) { + RewriteControlFlowOpRegion(builder, region_if, *region_idx, + region_if.getOperand(*region_idx + 1).getType(), + ops_to_visit, control_flow_blocks, token); + return true; + } + + return false; +} + +// Rewrites a `mhlo.while` op to receive and forward a `mhlo.token`. Operands to +// the op for all of its regions are extended to have an extra operand `token`. +void RewriteRegionWhileOp(OpBuilder& builder, WhileOp region_while, + SmallVectorImpl& ops_to_visit, + Value token) { + llvm::SmallDenseMap rewritten_operands; + + // Rewrite region operand to have an extra operand `token`. 
+ Value new_val_operand = + GetValueWithToken(builder, region_while.val(), token, rewritten_operands); + + auto new_result_type = GetTypeWithToken(builder, region_while.getType()); + + // Create new `mhlo.while` op with extra token operand and result. + auto new_while = builder.create(region_while.getLoc(), + new_result_type, new_val_operand); + + // Move all regions from the old `mhlo.while` op to its replacement. + new_while.cond().takeBody(region_while.cond()); + new_while.body().takeBody(region_while.body()); + + // Forward result from old `mhlo.while` with replacement, and unpack result + // when necessary. + ReplaceWithTupleResult(builder, region_while.getResult(), + new_while.getResult()); + + auto new_token = builder.create( + new_while.getLoc(), new_while.getResult(), + new_while.getResult().getType().cast().size() - 1); + + region_while.erase(); + + // Remove leftover operands to old `mhlo.while` if they have no uses. + for (auto& rewritten_operand : rewritten_operands) + if (auto tuple_op = rewritten_operand.getFirst().getDefiningOp()) + if (tuple_op.use_empty()) tuple_op.erase(); + + // Next op to visit. The replacement is visited but at its first region. The + // token result of the new region if is propagated. + ops_to_visit.push_back({/*region_idx=*/0, new_token, new_while}); +} + +// Rewrites an `mhlo.while` op or its region. If `region_idx` is not set, the op +// operands and results are rewritten. If `region_idx` is set, region +// `region_idx` is rewritten to take in and return an additional token. Returns +// true if the op or its region was rewritten. +bool ProcessRegionWhileOp( + OpBuilder& builder, WhileOp region_while, Optional region_idx, + SmallVectorImpl& ops_to_visit, + const llvm::SmallPtrSetImpl& control_flow_blocks, Value token) { + builder.setInsertionPoint(region_while); + + if (!region_idx) { + RewriteRegionWhileOp(builder, region_while, ops_to_visit, token); + return true; + } + + if (*region_idx < region_while.getNumRegions()) { + RewriteControlFlowOpRegion(builder, region_while, *region_idx, + region_while.val().getType(), ops_to_visit, + control_flow_blocks, token); + return true; + } + + return false; +} + +// Updates function type based on current function body block arguments and +// terminator operand types. +void UpdateFunctionType(OpBuilder& builder, FuncOp func, Block& func_body) { auto new_argument_types = llvm::to_vector<4>(func_body.getArgumentTypes()); - auto new_result_types = llvm::to_vector<4>(new_return.getOperandTypes()); + auto new_result_types = + llvm::to_vector<4>(func_body.getTerminator()->getOperandTypes()); func.setType(FunctionType::get(new_argument_types, new_result_types, builder.getContext())); } -// Rewrites a function body and communication ops inside. The function may -// either be rewritten to create a token or take in and return a token, -// depending on its visibility and if there are any callers. +// Replaces a function terminator `return` with another `return` that has an +// extra `mhlo.token` operand. +void RewriteFunctionTerminator(OpBuilder& builder, mlir::ReturnOp terminator, + Value token) { + auto new_results = llvm::to_vector<4>(terminator.getOperands()); + new_results.push_back(token); + builder.setInsertionPoint(terminator); + builder.create(terminator.getLoc(), new_results); + terminator.erase(); +} + +// Rewrites a function body and communication ops inside. Region control flow +// are updated when necessary, to propagate tokens. 
The function may either be +// rewritten to create a token or take in and return a token, depending on its +// visibility and if there are any callers. LogicalResult RewriteFunction( OpBuilder& builder, int64_t& channel_id, ModuleOp module, FuncOp func, - const llvm::SmallDenseMap& funcs) { + const llvm::SmallDenseMap& funcs, + const llvm::SmallPtrSetImpl& control_flow_ops, + const llvm::SmallPtrSetImpl& control_flow_blocks, bool is_clone) { MLIRContext* context = module.getContext(); if (!llvm::hasSingleElement(func.getBody())) return func.emitError() << "'" << FuncOp::getOperationName() << "' ops with more than one block are not supported"; - bool rewrite_block = !func.isPublic() && !func.symbolKnownUseEmpty(module); + bool rewrite_block = + is_clone || (!func.isPublic() && !func.symbolKnownUseEmpty(module)); Block& func_body = func.front(); builder.setInsertionPointToStart(&func_body); - auto token_type = mlir::mhlo::TokenType::get(context); + auto token_type = TokenType::get(context); // If a function is public, it's signature should not be modified, and instead // a token will be created. Otherwise a token block argument is inserted. - Value token = rewrite_block - ? func_body.addArgument(token_type) + Value init_token = + rewrite_block ? func_body.addArgument(token_type) : builder.create(func.getLoc(), token_type) .getResult(); - for (Operation& op : llvm::make_early_inc_range(func_body)) { - if (auto host_compute = dyn_cast(op)) { + // Stack to keep track of region based control flow op nesting and current + // op to visit. + SmallVector ops_to_visit{ + {/*region_idx=*/llvm::None, init_token, &func_body.front()}}; + + while (!ops_to_visit.empty()) { + OpVisitorState op_to_visit = ops_to_visit.pop_back_val(); + Operation* curr_op = op_to_visit.op; + + Value token = op_to_visit.token; + // Ops may be removed, so the next op is kept track of beforehand. + Operation* next_op = curr_op->getNextNode(); + + if (auto host_compute = dyn_cast(curr_op)) { token = RewriteHostComputeOp(builder, channel_id, host_compute, token); - } else if (auto send_to_host = dyn_cast(op)) { + } else if (auto send_to_host = dyn_cast(curr_op)) { token = RewriteSendToHostOp(builder, channel_id, send_to_host, token); - } else if (auto recv_from_host = dyn_cast(op)) { + } else if (auto recv_from_host = dyn_cast(curr_op)) { token = RewriteRecvFromHostOp(builder, channel_id, recv_from_host, token); - } else if (auto call = dyn_cast(op)) { + } else if (auto call = dyn_cast(curr_op)) { // Only `mlir::CallOp` is supported as this requires knowing how to // rewrite arguments and results to a function. auto it = funcs.find(call.getCallee()); - if (it == funcs.end()) continue; - FuncOp clone = it->getSecond().clone; - Optional symbol_name = - clone ? Optional(clone.getName()) : llvm::None; - // If the function being called is to be cloned, update the call to also - // point to the cloned function. - token = RewriteCallOp(builder, call, symbol_name, token); + if (it != funcs.end()) { + FuncOp clone = it->getSecond().clone; + Optional symbol_name = + clone ? Optional(clone.getName()) : llvm::None; + // If the function being called is to be cloned, update the call to also + // point to the cloned function. 
+ token = RewriteCallOp(builder, call, symbol_name, token); + } + } else if (auto region_if = dyn_cast(curr_op)) { + if (op_to_visit.region_idx || control_flow_ops.contains(region_if)) + if (ProcessRegionIfOp(builder, region_if, op_to_visit.region_idx, + ops_to_visit, control_flow_blocks, token)) + continue; + } else if (auto region_while = dyn_cast(curr_op)) { + if (op_to_visit.region_idx || control_flow_ops.contains(region_while)) + if (ProcessRegionWhileOp(builder, region_while, op_to_visit.region_idx, + ops_to_visit, control_flow_blocks, token)) + continue; + } else if (auto region_terminator = dyn_cast(curr_op)) { + RewriteControlFlowTerminator(builder, region_terminator, token); + // There is no next op after the control flow op terminator, simply let the + // stack have one less element. + continue; + } else if (auto func_terminator = dyn_cast(curr_op)) { + if (rewrite_block) + RewriteFunctionTerminator(builder, func_terminator, token); + + // There is no next op after the function terminator, simply let the stack + // have one less element/be empty. + continue; + } + + // Visit next op. + ops_to_visit.push_back({/*region_idx=*/llvm::None, token, next_op}); } - if (rewrite_block) - RewriteFunctionTerminatorAndUpdateType(builder, func, func_body, token); + if (rewrite_block) UpdateFunctionType(builder, func, func_body); return success(); } +// Checks if a function call points to a function with communication ops. +bool IsFunctionCallWithCommunication( + Operation* op, + const llvm::SmallDenseMap& funcs_to_rewrite) { + if (auto call = dyn_cast(op)) + return funcs_to_rewrite.count(call.callee()); + + return false; +} + +// Collects all control flow op ancestors of communication ops or function calls +// with communication ops (transitively). +void GetCommunicationControlFlowOps( + FuncOp func, + const llvm::SmallDenseMap& funcs_to_rewrite, + llvm::SmallPtrSetImpl& control_flow_ops, + llvm::SmallPtrSetImpl& control_flow_blocks) { + func.walk([&](Operation* op) { + if (IsCommunicationOp(op) || + IsFunctionCallWithCommunication(op, funcs_to_rewrite)) + if (failed(GetControlFlowAncestors(op, control_flow_ops, + control_flow_blocks))) + llvm_unreachable( + "checking original function for control flow ancestors should have " + "errored first"); + }); +} + void LegalizeTFCommunication::runOnOperation() { auto module = getOperation(); - llvm::SmallDenseMap funcs = - GetFunctionsToRewrite(module); + llvm::SmallDenseMap funcs_to_rewrite; + if (failed(GetFunctionsToRewrite(module, funcs_to_rewrite))) + return signalPassFailure(); // Module level counter to make sure Channel Id's are unique.
int64_t channel_id = 1; OpBuilder builder(&getContext()); - for (const auto& func_and_name : funcs) { - FuncOp func = func_and_name.getSecond().original; - if (failed(RewriteFunction(builder, channel_id, module, func, funcs))) + for (const auto& func_and_name : funcs_to_rewrite) { + const auto& func_to_rewrite = func_and_name.getSecond(); + FuncOp func = func_to_rewrite.original; + if (failed(RewriteFunction(builder, channel_id, module, func, + funcs_to_rewrite, + func_to_rewrite.control_flow_ops, + func_to_rewrite.control_flow_blocks, + /*is_clone=*/false))) return signalPassFailure(); FuncOp clone = func_and_name.getSecond().clone; if (!clone) continue; - if (failed(RewriteFunction(builder, channel_id, module, clone, funcs))) - return signalPassFailure(); + llvm::SmallPtrSet clone_control_flow_ops; + llvm::SmallPtrSet clone_control_flow_blocks; + GetCommunicationControlFlowOps(clone, funcs_to_rewrite, + clone_control_flow_ops, + clone_control_flow_blocks); + if (failed(RewriteFunction(builder, channel_id, module, clone, + funcs_to_rewrite, clone_control_flow_ops, + clone_control_flow_blocks, + /*is_clone=*/true))) + llvm_unreachable( + "rewriting of original function should have errored first"); } } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc index 760252331e0..4e76baa6805 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc @@ -119,8 +119,8 @@ void LowerIf(TF::IfOp op, ModuleOp module) { // Import the regions for both the true and false cases. These regions // must be updated to tuple the return results together and use the xla hlo // return op. - ImportXlaRegion(op.then_func(), &if_op.true_branch(), loc); - ImportXlaRegion(op.else_func(), &if_op.false_branch(), loc); + ImportXlaRegion(op.then_function(), &if_op.true_branch(), loc); + ImportXlaRegion(op.else_function(), &if_op.false_branch(), loc); // De-tuple the results of the xla hlo if result. Detuple(if_op.getResult(), op.getResults(), &builder); @@ -172,8 +172,8 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) { // Import the regions for both the cond and body. These regions must be // updated to tuple the return results together and use the xla hlo return op. - ImportXlaRegion(op.body_func(), &while_op.body(), loc); - ImportXlaRegion(op.cond_func(), &while_op.cond(), loc, + ImportXlaRegion(op.body_function(), &while_op.body(), loc); + ImportXlaRegion(op.cond_function(), &while_op.cond(), loc, /*tuple_return=*/false); // De-tuple the results of the xla hlo while. 
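Stepping back from the diff: the communication pass above replaces straight-line iteration over the function body with an explicit stack of OpVisitorState entries, so a token can be threaded into and out of nested `mhlo.if`/`mhlo.while` regions. Below is a simplified, self-contained model of that traversal; the Node/Token types and handler names are hypothetical stand-ins, and the plumbing that returns a token out of a rewritten control flow op's result is elided here:

#include <optional>
#include <vector>

struct Node {
  Node* next = nullptr;        // next op in the same block
  std::vector<Node*> regions;  // first op of each nested region
  bool is_terminator = false;
};

using Token = int;  // stand-in for the mhlo.token value being threaded

struct VisitorState {
  std::optional<unsigned> region_idx;  // set when re-visiting a region op
  Token token;                         // token produced so far
  Node* op;
};

// Walks a block (and nested regions) in order, giving every visited op the
// current token, in the spirit of the ops_to_visit loop above.
void ThreadToken(Node* first_op, Token init_token) {
  std::vector<VisitorState> ops_to_visit{{std::nullopt, init_token, first_op}};
  while (!ops_to_visit.empty()) {
    VisitorState state = ops_to_visit.back();
    ops_to_visit.pop_back();
    Node* op = state.op;
    Token token = state.token;
    // Nothing follows a terminator (or the end of a block).
    if (op == nullptr || op->is_terminator) continue;
    unsigned region_idx = state.region_idx.value_or(0);
    if (region_idx < op->regions.size()) {
      // Re-visit this op at its next region once the current one is done,
      // then descend into the current region.
      ops_to_visit.push_back({region_idx + 1, token, op});
      ops_to_visit.push_back({std::nullopt, token, op->regions[region_idx]});
      continue;
    }
    // Leaf op (or all regions visited): pretend it consumes and produces a
    // token, then move on to the next op in the block.
    Token new_token = token + 1;
    ops_to_visit.push_back({std::nullopt, new_token, op->next});
  }
}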
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 0ef62deed7d..b1460421f16 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -31,7 +31,7 @@ def IEEEFloatTensor : TensorOf<[F16, F32, F64]>; //===----------------------------------------------------------------------===// def FeatureDimension : NativeCodeCall< - "getFeatureDimensionAttr($_builder, $0, $1)">; + "getFeatureDimensionAttr($_builder, $0.getValue(), $1)">; def FalseBoolAttr : AttrConstraint>; def TrueBoolAttr : AttrConstraint>; @@ -51,6 +51,10 @@ def GetHLOAxisFromTFAxisVariadic : NativeCodeCall< "$0, (*$1.begin()).getType().cast().getRank(), " "&$_builder)">; +def CastElementsToI64Elements : NativeCodeCall< + "hlo::ConvertElementsAttr(" + "$0, $_builder.getIntegerType(64)).cast()">; + def : Pattern< (TF_FusedBatchNormOp:$root $x, $scale, $offset, $mean, $variance, $epsilon, $exponential_avg_factor, $data_format, @@ -82,7 +86,7 @@ def AreBroadcastCompatible : Constraint, "types must be broadcastable">; class DirectBinaryPat - : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r), + : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp], @@ -128,7 +132,7 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), // return x / y; // } // -// BraodcastToDimensions is used to compute the broadcast attr to higher +// BroadcastToDimensions is used to compute the broadcast attr to higher // dimensions. This computes the broadcast of 'l' to broadcast('l', 'r') // without returning the broadcast of 'r' to broadcast('l', 'r'). // @@ -143,14 +147,14 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), (HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (GetScalarOfType<0> $r)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), (BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ), - (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)), - (HLOClient_BroadcastDivOp - (HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l), + (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)), + (HLOClient_BroadcastDivOp + (HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l), (HLOClient_BroadcastSubOp (HLO_AbsOp $r), (HLO_ConstOp (GetScalarOfType<1> $r)), (NullDenseIntElementsAttr)), (BinBroadcastDimensions $l, $r))), - (HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs))), + (HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs))), [(SignedIntTensor $l)]>; // Performs a substitution of FloorMod designed to correct for possibly negative @@ -175,8 +179,8 @@ def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r), (BinBroadcastDimensions $rem, $r_zeros), HLO_COMPARISON_DIRECTION_LT), (BinBroadcastDimensions $r_cmp, $rem_cmp), HLO_COMPARISON_DIRECTION_NE), (NullDenseIntElementsAttr)), - (HLOClient_BroadcastAddOp $r, - $rem, (BinBroadcastDimensions $r, $rem)), $rem)>; + (HLOClient_BroadcastAddOp $r, + $rem, (BinBroadcastDimensions $r, $rem)), $rem)>; //===----------------------------------------------------------------------===// // Logical & bitwise binary op patterns. 
@@ -255,12 +259,16 @@ def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis)), [(HasRankedFirstOperand $inputs)]>; //===----------------------------------------------------------------------===// -// CrossReplicaSum op patterns. +// CollectivePermute op patterns. //===----------------------------------------------------------------------===// -def CastElementsToI64Elements : NativeCodeCall< - "hlo::ConvertElementsAttr(" - "$0, $_builder.getIntegerType(64)).cast()">; +def : Pat<(TF_CollectivePermuteOp $input, (TF_ConstOp $source_target_pairs)), + (HLO_CollectivePermuteOp $input, + (CastElementsToI64Elements $source_target_pairs))>; + +//===----------------------------------------------------------------------===// +// CrossReplicaSum op patterns. +//===----------------------------------------------------------------------===// def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)), (HLO_CrossReplicaSumOp $input, @@ -277,9 +285,19 @@ def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (TF_ConstOp $group_assignment), // FFT op patterns. //===----------------------------------------------------------------------===// -def : Pat<(TF_RFFTOp $input, (TF_ConstOp I32ElementsAttr:$fft_length)), - (HLO_FftOp $input, HLO_FFT_TYPE_RFFT, - (CastElementsToI64Elements $fft_length))>; +def GetInnerDimFromValue : NativeCodeCall< + "GetInnerDimFromValue($0.getType().cast(), &$_builder)">; + +def CheckInnerDimStatic + : Constraint(), &$_builder)">>; + +def : Pat<(TF_FFTOp:$res $input), + (HLO_FftOp $input, HLO_FFT_TYPE_FFT, (GetInnerDimFromValue $res)), + [(CheckInnerDimStatic $input)]>; + +def : Pat<(TF_IFFTOp:$res $input), + (HLO_FftOp $input, HLO_FFT_TYPE_IFFT, (GetInnerDimFromValue $res)), + [(CheckInnerDimStatic $input)]>; //===----------------------------------------------------------------------===// // GatherV2 op patterns. @@ -427,6 +445,35 @@ def : Pat<(TF_ConstOp:$res ElementsAttr:$value), (TensorCastOp (HLO_ConstOp $value)), [(HLO_Tensor $res)]>; +//===----------------------------------------------------------------------===// +// Elu op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_EluOp AnyRankedTensor:$features), + (HLO_SelectOp + (HLOClient_BroadcastCompareOp + $features, + (HLO_ConstOp:$zero (GetScalarOfType<0> $features)), + (BinBroadcastDimensions $zero, $features), + HLO_COMPARISON_DIRECTION_GT), + $features, + (HLO_Expm1Op $features))>; + +def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features), + (HLO_SelectOp + (HLOClient_BroadcastCompareOp + $features, + (HLO_ConstOp:$zero (GetScalarOfType<0> $features)), + (BinBroadcastDimensions $zero, $features), + HLO_COMPARISON_DIRECTION_GT), + $gradients, + (HLO_MulOp + $gradients, + (HLOClient_BroadcastAddOp + $features, + (HLO_ConstOp:$one (GetScalarOfType<1> $features)), + (BinBroadcastDimensions $one, $features))))>; + //===----------------------------------------------------------------------===// // Relu op patterns. 
//===----------------------------------------------------------------------===// @@ -542,24 +589,12 @@ foreach Mapping = [ [TF_SinOp, HLO_SinOp], [TF_SqrtOp, HLO_SqrtOp], [TF_TanhOp, HLO_TanhOp], + [TF_TanOp, HLOClient_TanOp], ] in { def : Pat<(Mapping[0] HLO_Tensor:$input), (Mapping[1] $input)>; } -// Expand acos to MHLO dialect as follows: -// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1 -// = pi if x == -1 -def : Pat<(HLOClient_AcosOp $input), (HLO_SelectOp - (HLO_CompareOp $input, (HLO_ConstOp (ConstantSplat<"0"> $input)), - HLO_COMPARISON_DIRECTION_NE), - (HLO_MulOp (HLO_ConstOp (ConstantSplat<"2"> $input)), - (HLO_Atan2Op (HLO_SqrtOp (HLO_SubOp - (HLO_ConstOp (ConstantSplat<"1"> $input)), - (HLO_MulOp $input, $input))), - (HLO_AddOp (HLO_ConstOp (ConstantSplat<"1"> $input)), $input))), - (HLO_ConstOp (ConstantSplat<"M_PI"> $input)))>; - // TODO(bixia): Lower Cast with a Complex type source operand or with // Truncate=True for floating point value conversions. def : Pat<(TF_CastOp HLO_Tensor:$arg, ConstBoolAttrFalse), @@ -594,6 +629,9 @@ def : Pat<(TF_BitcastOp:$res HLO_Tensor:$arg), (HLO_BitcastConvertOp $arg), [(BothElementTypesSameWidthIntOrFloat $res, $arg)]>; +// TODO(jpienaar): Lower constant like to constant to broadcast if dynamic +// and going to MHLO. + //===----------------------------------------------------------------------===// // Random ops. //===----------------------------------------------------------------------===// @@ -657,3 +695,19 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), ), (replaceWithValue $output) ]>; + +//===----------------------------------------------------------------------===// +// XlaGather op. +//===----------------------------------------------------------------------===// + +def ToGatherDimNumsAttr : NativeCodeCall<"GetGatherDimNumsAttr($0, &$_builder)">; + +def HasValidGatherDims : Constraint>; + +def : Pat<(TF_XlaGatherOp $operand, $start_indices, (TF_ConstOp $slice_sizes), + $dimension_numbers, $indices_are_sorted), + (HLO_GatherOp $operand, $start_indices, + (ToGatherDimNumsAttr $dimension_numbers), + (CastElementsToI64Elements $slice_sizes), + $indices_are_sorted), + [(HasValidGatherDims $dimension_numbers)]>; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index bb50fc198c8..b06edcd3db8 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -74,17 +75,14 @@ limitations under the License. namespace mlir { namespace mhlo { -namespace { -template -using InlinedVector = tensorflow::gtl::InlinedVector; // non-absl ok - -static bool IsOpAllowlisted(Operation* op) { +bool IsOpAllowedTf2XlaFallback(Operation* op) { // Allowlisted TensorFlow ops are known to have well behaved tf2xla kernels // building valid MLIR using MlirHloBuilder. 
// TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for // all tf2xla kernels. // clang-format off + static llvm::SmallDenseSet ops = { TypeID::get(), TypeID::get(), @@ -104,6 +102,11 @@ static bool IsOpAllowlisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -112,12 +115,17 @@ static bool IsOpAllowlisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -126,6 +134,7 @@ static bool IsOpAllowlisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -137,10 +146,11 @@ static bool IsOpAllowlisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -151,24 +161,38 @@ static bool IsOpAllowlisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + // TODO(hinsu): Canonicalize QuantizeAndDequantize and + // QuantizeAndDequantizeV2 to QuantizeAndDequantizeV3 by converting + // attributes to operands. + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -177,6 +201,7 @@ static bool IsOpAllowlisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -188,9 +213,17 @@ static bool IsOpAllowlisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -198,6 +231,7 @@ static bool IsOpAllowlisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -215,6 +249,11 @@ static bool IsOpAllowlisted(Operation* op) { return ops.count(abstractOp->typeID); } +namespace { + +template +using InlinedVector = tensorflow::gtl::InlinedVector; // non-absl ok + static std::unique_ptr CreateDeviceMgr( const std::string& device_type) { // Register compilation kernels for all registered XLA backends. @@ -492,12 +531,14 @@ tensorflow::XlaExpression Tf2XlaRewriter::GetExprForOperand(Value operand, class Tf2XlaRewritePattern : public RewritePattern { public: + // Set benefit to 0 (= least benefit) so this pattern is only used as a + // fallback. 
explicit Tf2XlaRewritePattern(const std::string& device_type) - : RewritePattern(1, MatchAnyOpTypeTag()), device_type_(device_type) {} + : RewritePattern(0, MatchAnyOpTypeTag()), device_type_(device_type) {} LogicalResult matchAndRewrite(Operation* op, PatternRewriter& rewriter) const override { - if (!IsOpAllowlisted(op)) return failure(); + if (!IsOpAllowedTf2XlaFallback(op)) return failure(); return Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_); } @@ -527,8 +568,7 @@ class LegalizeTF : public PassWrapper { // global device type for all TensorFlow ops. Option device_type_{ *this, "device-type", - llvm::cl::desc("XLA device type for execution of TensorFlow ops. " - "Supports XLA_CPU_JIT and XLA_TPU_JIT for now.")}; + llvm::cl::desc("XLA device type for execution of TensorFlow ops.")}; }; static PassRegistration pass( diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index 832bad2dcc8..ef362d95b97 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project @@ -34,6 +35,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassOptions.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" @@ -134,6 +136,11 @@ Status ConvertModule(std::unique_ptr hlo_module, ModuleOp module, // MLIR LHLO. 
class XlaHloToLhloPass : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: XlaHloToLhloPass() = default; XlaHloToLhloPass(const XlaHloToLhloPass&) {} @@ -182,7 +189,10 @@ template StatusOr LhloDialectEmitter::CreateOpWithoutAttrs( HloInstruction* instr) { Location loc = getLocation(instr); - ArrayRef> attrs; + std::pair attrs[] = { + {Identifier::get("name", builder_.getContext()), + builder_.getStringAttr(instr->name())}, + }; ArrayRef rets{}; llvm::SmallVector operands; @@ -252,15 +262,14 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) { return Status::OK(); } -StatusOr LhloDialectEmitter::EmitSortOp( - HloInstruction* instr) { +StatusOr LhloDialectEmitter::EmitSortOp(HloInstruction* instr) { TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs(instr)); auto* sort_instr = ::xla::Cast<::xla::HloSortInstruction>(instr); sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension())); sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable())); TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion( *sort_instr->called_computations()[0], &sort.comparator(), &builder_)); - return sort.getOperation(); + return sort; } Status LhloDialectEmitter::HandleSort(HloInstruction* instr) { @@ -327,19 +336,17 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr, // create another view to adjust the slice for the shape of the instruction. Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr, SmallVectorImpl* values) { - // In terms of cache key, we have several choices: - // * Use `instr`. It's the easiest, but it creates different cache entries for - // aliased buffers, which could have been deduplicated. - // * Use the actual content as the key, aka a tree of allocation slices. - // * Somewhere in the middle, use the allocation slice for the instruction. If - // `instr` is a tuple, the key is the allocated buffer for the tuple itself - // (an array of pointers). + // Cache generated ViewOp and StaticMemRefCastOp by instruction. We could have + // gone fancier and done the following caching: + // %range = ViewOp(%allocation, %offset) : memref + // %typed_range = ViewOp(%range) : memref // - // We choose the third approach for simplicity. - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, - assignment_.GetUniqueTopLevelSlice(instr)); - SliceKey slice_key(slice.allocation(), slice.offset(), slice.size()); - auto result = slices_.try_emplace(slice_key, llvm::SmallVector{}); + // where %range is cached. In theory this gives alias analysis an easier + // time, since the identity of %range defines aliasing. However, + // %typed_range can't be cached, as different buffers with different types and + // shapes may still alias. Creating two ViewOps doesn't seem worth the + // effort for slightly easier aliasing, so we don't over-optimize here.
+ auto result = slices_.try_emplace(instr, llvm::SmallVector{}); llvm::SmallVectorImpl& new_values = result.first->second; if (result.second) { ::xla::ShapeIndex shape_index; @@ -439,7 +446,7 @@ Status LhloDialectEmitter::Initialize() { builder_.setInsertionPointToEnd(block); auto return_op = builder_.create(builder_.getUnknownLoc()); - builder_ = mlir::OpBuilder(return_op); + builder_ = OpBuilder(return_op); return Status::OK(); } @@ -450,6 +457,9 @@ std::unique_ptr> createXlaHloToLhloWithXlaPass() { Status HloToLhloModule(const BufferAssignment& assignment, const HloModule& hlo_module, ModuleOp module) { + module.getContext() + ->loadDialect(); HloComputation* computation = hlo_module.entry_computation(); LhloDialectEmitter emitter(assignment, *computation, module); @@ -463,15 +473,14 @@ Status HloToLhloModule(const BufferAssignment& assignment, return computation->AcceptOrdered(&emitter, ordering); } -mlir::OwningModuleRef HloTextToLhloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context) { +OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input, + MLIRContext* context) { StatusOr> maybe_module = xla::ParseAndReturnUnverifiedModule( absl::string_view(input.data(), input.size())); TF_CHECK_OK(maybe_module.status()); - mlir::OwningModuleRef module = - mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); + OwningModuleRef module = ModuleOp::create(UnknownLoc::get(context)); TF_CHECK_OK( ConvertModule(maybe_module.ConsumeValueOrDie(), module.get(), "Host")); diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index bdc977616b1..89514116254 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -41,7 +42,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { builder_(module.getContext()), i8_type_(builder_.getIntegerType(8)) {} - ::xla::StatusOr EmitSortOp(::xla::HloInstruction* instr); + ::xla::StatusOr EmitSortOp(::xla::HloInstruction* instr); private: template @@ -86,9 +87,9 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { // (see below). llvm::DenseMap allocations_; - // This map provides access to MLIR buffers for each HLO instruction, keyed by - // its buffer slice. A slice is contained in a BufferAllocation, and has an - // offset and a size. + // This map provides access to MLIR buffers for each HLO instruction, keyed + // instruction identity. A slice is contained in a BufferAllocation, and has + // an offset and a size. // // As for why we don't use HloInstruction*, see GetOrCreateView(), but mostly // we want to leverage better of the aliased buffers. @@ -101,8 +102,8 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { // // `slices_` is populated lazily in the `GetOrCreateView()` helper as we // process every instruction. - using SliceKey = std::tuple; - llvm::DenseMap> slices_; + llvm::DenseMap> + slices_; // The BufferAssignment computed by XLA ahead of time. 
const ::xla::BufferAssignment& assignment_; diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 8850581f0bd..45166941620 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -36,8 +36,13 @@ namespace mhlo { /// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is /// false, emits an error if there is any operation that can't be legalized. +/// When `tf2xla_fallback_device_type` is not `None`, also uses legalization +/// patterns from TF2XLA fallback for provided device type (see +/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not +/// used. std::unique_ptr> createLegalizeTFPass( - bool allow_partial_conversion = false, bool legalize_chlo = true); + bool allow_partial_conversion = false, bool legalize_chlo = true, + llvm::Optional tf2xla_fallback_device_type = llvm::None); /// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the /// specified device type. @@ -53,6 +58,9 @@ void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type, void PopulateLegalizeTfPatterns(MLIRContext* context, OwningRewritePatternList* patterns); +/// Checks whether the op is supported by the Tf2Xla fallback for legalization. +bool IsOpAllowedTf2XlaFallback(Operation* op); + /// Lowers from TF dialect's control flow to HLO dialect's control flow. std::unique_ptr> createLegalizeTFControlFlowPass(); @@ -60,8 +68,14 @@ std::unique_ptr> createLegalizeTFControlFlowPass(); /// dialect using the conversion patterns registered by the HLO dialect. When /// allow_partial_conversion is false, emits an error if there is any operation /// that can't be legalized. -LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false, - bool legalize_chlo = true); +/// When `tf2xla_fallback_device_type` is not `None`, also uses legalization +/// patterns from TF2XLA fallback for provided device type (see +/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not +/// used. +LogicalResult legalizeTF( + Operation* op, bool allow_partial_conversion = false, + bool legalize_chlo = true, + llvm::Optional tf2xla_fallback_device_type = llvm::None); // Legalizes TF/XLA communication ops (TF dialect) to HLO dialect communication // ops. 
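As a usage illustration of the `tf2xla_fallback_device_type` option declared in passes.h above, a hypothetical pipeline setup might look like the following; the pass manager boilerplate and the "XLA_TPU_JIT" device string are assumptions for the sketch, not taken from this patch:

#include "llvm/ADT/StringRef.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"

// Adds the TF->HLO legalization pass with the TF2XLA fallback enabled for a
// given device type to an existing pass manager.
void AddLegalizeTfWithFallback(mlir::PassManager& pm) {
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
      /*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
      /*tf2xla_fallback_device_type=*/llvm::StringRef("XLA_TPU_JIT")));
}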
diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc index afc36916348..b725f56b455 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -43,47 +43,41 @@ using xla::ShapeUtil; namespace xla { PrimitiveType TypeToPrimitiveType(mlir::Type type) { - switch (type.getKind()) { - case mlir::StandardTypes::BF16: - return PrimitiveType::BF16; - case mlir::StandardTypes::Complex: { - mlir::Type element_ty = type.cast().getElementType(); - switch (element_ty.getKind()) { - case mlir::StandardTypes::F32: - return PrimitiveType::C64; - case mlir::StandardTypes::F64: - return PrimitiveType::C128; - default: - return PrimitiveType::PRIMITIVE_TYPE_INVALID; - } + if (type.isBF16()) { + return PrimitiveType::BF16; + } else if (type.isF16()) { + return PrimitiveType::F16; + } else if (type.isF32()) { + return PrimitiveType::F32; + } else if (type.isF64()) { + return PrimitiveType::F64; + } else if (auto complex_type = type.dyn_cast()) { + mlir::Type element_ty = complex_type.getElementType(); + if (element_ty.isF32()) { + return PrimitiveType::C64; + + } else if (element_ty.isF64()) { + return PrimitiveType::C128; } - case mlir::StandardTypes::F16: - return PrimitiveType::F16; - case mlir::StandardTypes::F32: - return PrimitiveType::F32; - case mlir::StandardTypes::F64: - return PrimitiveType::F64; - case mlir::StandardTypes::Integer: { - const auto integer = type.cast(); - bool is_unsigned = integer.isUnsigned(); - switch (integer.getWidth()) { - case 1: - return PrimitiveType::PRED; - case 8: - return is_unsigned ? PrimitiveType::U8 : PrimitiveType::S8; - case 16: - return is_unsigned ? PrimitiveType::U16 : PrimitiveType::S16; - case 32: - return is_unsigned ? PrimitiveType::U32 : PrimitiveType::S32; - case 64: - return is_unsigned ? PrimitiveType::U64 : PrimitiveType::S64; - default: - return PrimitiveType::PRIMITIVE_TYPE_INVALID; - } + return PrimitiveType::PRIMITIVE_TYPE_INVALID; + } else if (auto integer_type = type.dyn_cast()) { + bool is_unsigned = integer_type.isUnsigned(); + switch (integer_type.getWidth()) { + case 1: + return PrimitiveType::PRED; + case 8: + return is_unsigned ? PrimitiveType::U8 : PrimitiveType::S8; + case 16: + return is_unsigned ? PrimitiveType::U16 : PrimitiveType::S16; + case 32: + return is_unsigned ? PrimitiveType::U32 : PrimitiveType::S32; + case 64: + return is_unsigned ? 
PrimitiveType::U64 : PrimitiveType::S64; + default: + return PrimitiveType::PRIMITIVE_TYPE_INVALID; } - default: - return PrimitiveType::PRIMITIVE_TYPE_INVALID; } + return PrimitiveType::PRIMITIVE_TYPE_INVALID; } StatusOr TypeToShape( @@ -108,108 +102,89 @@ Shape TypeToShape(mlir::Type type) { if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID) return ShapeUtil::MakeShape(ptype, {}); - switch (type.getKind()) { - case mlir::StandardTypes::BF16: - case mlir::StandardTypes::F32: - case mlir::StandardTypes::F64: - case mlir::StandardTypes::Integer: { - auto* context = type.getContext(); - mlir::emitError(mlir::UnknownLoc::get(context)) - << "lowering should have been handled by primitive type lowering for " - << debugString(type); - break; + if (type.isBF16() || type.isF32() || type.isF64() || + type.isa()) { + auto* context = type.getContext(); + mlir::emitError(mlir::UnknownLoc::get(context)) + << "lowering should have been handled by primitive type lowering for " + << debugString(type); + } else if (auto v = type.dyn_cast()) { + llvm::SmallVector span(v.getShape().begin(), v.getShape().end()); + mlir::Type element_type = v.getElementType(); + PrimitiveType primitive_type = TypeToPrimitiveType(element_type); + if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID) + return ShapeUtil::MakeShape(primitive_type, span); + } else if (auto m = type.dyn_cast()) { + llvm::SmallVector span(m.getShape().begin(), m.getShape().end()); + mlir::Type element_type = m.getElementType(); + // Treat a memref of a vector as if it was a memref of primitive type with + // the vector dimensions at the end. + if (auto v = element_type.dyn_cast()) { + element_type = v.getElementType(); + span.insert(span.end(), v.getShape().begin(), v.getShape().end()); } - case mlir::StandardTypes::Vector: { - const auto v = type.cast(); - llvm::SmallVector span(v.getShape().begin(), - v.getShape().end()); - mlir::Type element_type = v.getElementType(); - PrimitiveType primitive_type = TypeToPrimitiveType(element_type); - if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID) - return ShapeUtil::MakeShape(primitive_type, span); - break; - } - case mlir::StandardTypes::MemRef: { - const auto m = type.cast(); - llvm::SmallVector span(m.getShape().begin(), - m.getShape().end()); - mlir::Type element_type = m.getElementType(); - // Treat a memref of a vector as if it was a memref of primitive type with - // the vector dimensions at the end. - if (auto v = element_type.dyn_cast()) { - element_type = v.getElementType(); - span.insert(span.end(), v.getShape().begin(), v.getShape().end()); + PrimitiveType primitive_type = TypeToPrimitiveType(element_type); + if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) return {}; + // For the primitive type case, the shape of the memref is similar to the + // vector type case (i.e., it is, modulo the layout, the same dimensions + // and primitive type). 
+ if (m.getAffineMaps().empty()) + return ShapeUtil::MakeShape(primitive_type, span); + + if (m.getAffineMaps().size() == 1) { + llvm::SmallVector strides; + int64_t offset; + if (failed(mlir::getStridesAndOffset(m, strides, offset))) return {}; + + llvm::SmallVector, 4> strides_with_indices; + for (const auto& e : llvm::enumerate(strides)) { + strides_with_indices.push_back({e.value(), e.index()}); } - PrimitiveType primitive_type = TypeToPrimitiveType(element_type); - if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) break; - // For the primitive type case, the shape of the memref is similar to the - // vector type case (i.e., it is, modulo the layout, the same dimensions - // and primitive type). - if (m.getAffineMaps().empty()) - return ShapeUtil::MakeShape(primitive_type, span); + std::sort(strides_with_indices.begin(), strides_with_indices.end()); - if (m.getAffineMaps().size() == 1) { - llvm::SmallVector strides; - int64_t offset; - if (failed(mlir::getStridesAndOffset(m, strides, offset))) return {}; + llvm::SmallVector minor_to_major; + int64_t stride = 1; + for (const auto& pr : strides_with_indices) { + minor_to_major.push_back(pr.second); - llvm::SmallVector, 4> strides_with_indices; - for (const auto& e : llvm::enumerate(strides)) { - strides_with_indices.push_back({e.value(), e.index()}); - } - std::sort(strides_with_indices.begin(), strides_with_indices.end()); + // Either the affine map is not perfectly strided, or the dimensions + // recovered from strides don't match the actual dimensions in shapes. + if (stride != pr.first) return {}; - llvm::SmallVector minor_to_major; - int64_t stride = 1; - for (const auto& pr : strides_with_indices) { - minor_to_major.push_back(pr.second); - - // Either the affine map is not perfectly strided, or the dimensions - // recovered from strides don't match the actual dimensions in shapes. - if (stride != pr.first) return {}; - - stride *= m.getShape()[pr.second]; - } - - llvm::SmallVector dimensions(m.getShape().begin(), - m.getShape().end()); - return ::xla::ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, - minor_to_major); + stride *= m.getShape()[pr.second]; } - break; + + llvm::SmallVector dimensions(m.getShape().begin(), + m.getShape().end()); + return ::xla::ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, + minor_to_major); } - case mlir::StandardTypes::RankedTensor: { - // TODO(jpienaar): This is only handling the base case with primitive - // element type. - const auto t = type.cast(); - llvm::SmallVector span(t.getShape().begin(), - t.getShape().end()); - // Only fully static shapes are supported. - // TODO(b/115638799): Update once xla::Shape can support dynamic shapes. - if (std::find(t.getShape().begin(), t.getShape().end(), -1) != - t.getShape().end()) - break; - mlir::Type element_type = t.getElementType(); - PrimitiveType primitive_type = TypeToPrimitiveType(element_type); - // Only primitive element type supported. - if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID) - return ShapeUtil::MakeShape(primitive_type, span); - break; + } else if (auto t = type.dyn_cast()) { + // TODO(jpienaar): This is only handling the base case with primitive + // element type. + llvm::SmallVector span(t.getShape().begin(), t.getShape().end()); + // Only fully static shapes are supported. + // TODO(b/115638799): Update once xla::Shape can support dynamic shapes. 
+ if (std::find(t.getShape().begin(), t.getShape().end(), -1) != + t.getShape().end()) + return {}; + mlir::Type element_type = t.getElementType(); + PrimitiveType primitive_type = TypeToPrimitiveType(element_type); + // Only primitive element type supported. + if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID) + return ShapeUtil::MakeShape(primitive_type, span); + } else if (auto tuple_type = type.dyn_cast()) { + llvm::SmallVector shapes; + shapes.reserve(tuple_type.size()); + for (mlir::Type sub_type : tuple_type.getTypes()) { + shapes.push_back(TypeToShape(sub_type)); } - case mlir::StandardTypes::Tuple: { - const auto t = type.cast(); - llvm::SmallVector shapes; - shapes.reserve(t.size()); - for (mlir::Type sub_type : t.getTypes()) { - shapes.push_back(TypeToShape(sub_type)); - } - return ShapeUtil::MakeTupleShape(shapes); - } - case mlir::mhlo::HLOTypes::Token: - return ShapeUtil::MakeTokenShape(); - default: - break; + return ShapeUtil::MakeTupleShape(shapes); + + } else if (type.isa()) { + return ShapeUtil::MakeTokenShape(); } + // Return empty XLA shape to signify error. No MLIR Type maps to a empty // Shape. return {}; diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc index 158671a6242..4ad44d1bd77 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc @@ -17,11 +17,15 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Support/MemoryBuffer.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h" +#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h" #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -30,19 +34,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/protobuf.h" -// NOLINTNEXTLINE -static llvm::cl::opt emit_use_tuple_arg( - "emit-use-tuple-args", - llvm::cl::desc( - "Emit HLO modules using tuples as args for the entry computation"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt emit_return_tuple( - "emit-return-tuple", - llvm::cl::desc("Emit HLO modules with entry computations returning tuple"), - llvm::cl::init(false)); - namespace xla { namespace { @@ -173,11 +164,17 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction( } // namespace xla +static void RegisterInputDialects(mlir::DialectRegistry& registry) { + registry.insert(); +} + static mlir::TranslateFromMLIRRegistration MlirHloToHloTranslate( - "mlir-hlo-to-hlo", xla::MlirHloToHloTranslateFunction); + "mlir-hlo-to-hlo", xla::MlirHloToHloTranslateFunction, + RegisterInputDialects); static mlir::TranslateFromMLIRRegistration MlirHloToHloTextTranslate( - "mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction); + "mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction, + RegisterInputDialects); static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate( "hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction); diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc new file mode 100644 index 00000000000..bfe4ed3844f --- /dev/null +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc @@ -0,0 +1,29 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h" + +// NOLINTNEXTLINE +llvm::cl::opt emit_use_tuple_arg( + "emit-use-tuple-args", + llvm::cl::desc( + "Emit HLO modules using tuples as args for the entry computation"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +llvm::cl::opt emit_return_tuple( + "emit-return-tuple", + llvm::cl::desc("Emit HLO modules with entry computations returning tuple"), + llvm::cl::init(false)); diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h new file mode 100644 index 00000000000..1d5a29a5fdb --- /dev/null +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_CL_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_CL_H_ + +#include "llvm/Support/CommandLine.h" + +// This file contains command-line options aimed to provide the parameters +// required by the MLIR module to XLA HLO conversion. It is only intended to be +// included by binaries. + +extern llvm::cl::opt emit_use_tuple_arg; +extern llvm::cl::opt emit_return_tuple; + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_CL_H_ diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index a3134fc1c94..30b8a7e5561 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -265,6 +265,7 @@ tf_xla_py_test( name = "categorical_op_test", size = "small", srcs = ["categorical_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -283,6 +284,7 @@ tf_xla_py_test( name = "cholesky_op_test", size = "medium", srcs = ["cholesky_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -347,6 +349,7 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["searchsorted_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -389,6 +392,7 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["matrix_inverse_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -411,6 +415,7 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["matrix_solve_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -429,6 +434,7 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["matrix_triangular_solve_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -469,7 +475,6 @@ tf_xla_py_test( enable_mlir_bridge = True, python_version = "PY3", tags = [ - "many_xla_args", "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_rocm", ], @@ -533,6 +538,7 @@ tf_xla_py_test( name = "depthwise_conv_op_test", size = "medium", srcs = ["depthwise_conv_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -632,6 +638,7 @@ tf_xla_py_test( name = "extract_image_patches_op_test", size = "small", srcs = ["extract_image_patches_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -688,6 +695,7 @@ tf_xla_py_test( name = "fft_test", size = "medium", srcs = ["fft_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 6, tags = [ @@ -783,6 +791,7 @@ tf_xla_py_test( name = "listdiff_op_test", size = "small", srcs = ["listdiff_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -821,6 +830,7 @@ tf_xla_py_test( name = "manip_ops_test", size = "small", srcs = ["manip_ops_test.py"], + 
enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -858,6 +868,7 @@ tf_xla_py_test( size = "medium", timeout = "long", srcs = ["matrix_diag_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -927,6 +938,7 @@ tf_xla_py_test( name = "pooling_ops_test", size = "medium", srcs = ["pooling_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 20, tags = [ @@ -1005,6 +1017,7 @@ tf_xla_py_test( "cpu", "cpu_ondemand", ], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1031,6 +1044,7 @@ tf_xla_py_test( "cpu", "cpu_ondemand", ], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1113,6 +1127,7 @@ tf_xla_py_test( name = "reverse_ops_test", size = "medium", srcs = ["reverse_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1164,6 +1179,7 @@ tf_xla_py_test( name = "scan_ops_test", size = "medium", srcs = ["scan_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1182,6 +1198,7 @@ tf_xla_py_test( name = "segment_reduction_ops_test", size = "medium", srcs = ["segment_reduction_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1204,6 +1221,7 @@ tf_xla_py_test( name = "spacetobatch_op_test", size = "medium", srcs = ["spacetobatch_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 3, tags = [ @@ -1279,6 +1297,7 @@ tf_xla_py_test( name = "stateless_random_ops_test", size = "medium", srcs = ["stateless_random_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1562,6 +1581,7 @@ tf_xla_py_test( name = "xla_device_test", size = "small", srcs = ["xla_device_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1685,6 +1705,7 @@ tf_cuda_cc_test( deps = [ "//tensorflow/cc:cc_ops", "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:xla_kernel_creator", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:core_cpu", @@ -1883,6 +1904,7 @@ tf_xla_py_test( name = "special_math_test", size = "medium", srcs = ["special_math_test.py"], + enable_mlir_bridge = True, shard_count = 5, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index 4bd2dfd9244..41877d39381 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -61,7 +60,7 @@ class CholeskyOpTest(xla_test.XLATestCase): dtypes.as_dtype(x.dtype), shape=x.shape) with 
self.test_scope(): chol = linalg_ops.cholesky(placeholder) - verification = math_ops.matmul(chol, chol, adjoint_b=True) + verification = test_util.matmul_without_tf32(chol, chol, adjoint_b=True) self._verifyCholeskyBase(sess, placeholder, x, chol, verification, atol) def testBasic(self): diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index 0202c582ef3..08aad66abe1 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -29,7 +29,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -65,7 +64,8 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): with self.test_scope(): x = linalg_ops.matrix_triangular_solve( placeholder_a, placeholder_b, lower=lower, adjoint=adjoint) - verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint) + verification = test_util.matmul_without_tf32( + placeholder_ca, x, adjoint_a=adjoint) self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca, placeholder_b, a, clean_a, b, verification, atol) @@ -135,6 +135,7 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): self._VerifyTriangularSolve( a.astype(np.float32), b.astype(np.float32), True, False, 1e-4) + @test_util.disable_mlir_bridge("Error handling") def testNonSquareCoefficientMatrix(self): rng = np.random.RandomState(0) for dtype in self.float_types: @@ -145,6 +146,7 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): linalg_ops.matrix_triangular_solve(a, b) @test_util.run_v2_only # Different error types + @test_util.disable_mlir_bridge("Error handling") def testWrongDimensionsV2(self): randn = np.random.RandomState(0).randn for dtype in self.float_types: @@ -156,6 +158,7 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): linalg_ops.matrix_triangular_solve(lhs, rhs) @test_util.run_v1_only("Different error types") + @test_util.disable_mlir_bridge("Error handling") def testWrongDimensionsV1(self): randn = np.random.RandomState(0).randn for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py index 5fcf254db82..f396e61f3d1 100644 --- a/tensorflow/compiler/tests/qr_op_test.py +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -24,12 +24,18 @@ from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test +@test_util.run_all_without_tensor_float_32( + "XLA QR op calls matmul. Also, matmul used for verification. 
Also with " + 'TensorFloat-32, mysterious "Unable to launch cuBLAS gemm" error ' + "occasionally occurs") +# TODO(b/165435566): Fix "Unable to launch cuBLAS gemm" error class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): def AdjustedNorm(self, x): @@ -73,7 +79,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): with self.session() as sess: x_tf = array_ops.placeholder(dtype) - with self.test_scope(): + with self.device_scope(): q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices) q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np}) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 9f963110cf3..0f19affc8e3 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -63,9 +63,9 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 7c36f8b13ca..440b7672d98 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -129,6 +130,7 @@ class CumsumTest(xla_test.XLATestCase): for axis in range(-6, 6, 3): self._compareAll(x, axis) + @test_util.disable_mlir_bridge("Error handling") def testInvalidAxis(self): x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) with self.session(), self.test_scope(): @@ -207,6 +209,7 @@ class CumprodTest(xla_test.XLATestCase): for axis in range(-6, 6, 3): self._compareAll(x, axis) + @test_util.disable_mlir_bridge("Error handling") def testInvalidAxis(self): x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) with self.session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/stateful_random_ops_test.py b/tensorflow/compiler/tests/stateful_random_ops_test.py index 343969c40d7..239b99de19e 100644 --- a/tensorflow/compiler/tests/stateful_random_ops_test.py +++ b/tensorflow/compiler/tests/stateful_random_ops_test.py @@ -25,7 +25,9 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.client import device_lib +from tensorflow.python.compat import compat from tensorflow.python.eager import def_function +from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops @@ -156,6 +158,10 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase): def testNewStateThreeFry(self): """Tests that the new state is correct (for ThreeFry). 
""" + if compat.forward_compatible(2020, 10, 25): + self.skipTest("The expected values in this test is inconsistent with " + "CPU/GPU. testXLAEqualsCPU has the correct checks of the " + "new states for the new version.") with ops.device(xla_device_name()): counter = 57 key = 0x1234 @@ -171,6 +177,10 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase): def testNewStatePhilox(self): """Tests that the new state is correct (for Philox). """ + if compat.forward_compatible(2020, 10, 25): + self.skipTest("The expected values in this test is inconsistent with " + "CPU/GPU. testXLAEqualsCPU has the correct checks of the " + "new states for the new version.") with ops.device(xla_device_name()): counter_low = 57 counter_high = 283 @@ -204,13 +214,39 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase): """Tests that XLA and CPU kernels generate the same integers.""" seed = 1234 shape = [315, 49] - with ops.device("/device:CPU:0"): - cpu = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX) - .uniform_full_int(shape=shape, dtype=dtype)) - with ops.device(xla_device_name()): - xla = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX) - .uniform_full_int(shape=shape, dtype=dtype)) - self.assertAllEqual(cpu, xla) + if compat.forward_compatible(2020, 10, 25): + with ops.device("/device:CPU:0"): + cpu_gen = random.Generator.from_seed( + seed=seed, alg=random.RNG_ALG_PHILOX) + with ops.device(xla_device_name()): + xla_gen = random.Generator.from_seed( + seed=seed, alg=random.RNG_ALG_PHILOX) + # Repeat multiple times to make sure that the state after + # number-generation are the same between CPU and XLA. + for _ in range(5): + with ops.device("/device:CPU:0"): + # Test both number-generation and skip + cpu = cpu_gen.uniform_full_int(shape=shape, dtype=dtype) + cpu_gen.skip(100) + with ops.device(xla_device_name()): + xla = xla_gen.uniform_full_int(shape=shape, dtype=dtype) + xla_gen.skip(100) + self.assertAllEqual(cpu, xla) + self.assertAllEqual(cpu_gen.state, xla_gen.state) + else: + # The old version doesn't guarantee that CPU and XLA are in the same state + # after number-generation, which is a bug. + with ops.device("/device:CPU:0"): + cpu = ( + random.Generator.from_seed( + seed=seed, alg=random.RNG_ALG_PHILOX).uniform_full_int( + shape=shape, dtype=dtype)) + with ops.device(xla_device_name()): + xla = ( + random.Generator.from_seed( + seed=seed, alg=random.RNG_ALG_PHILOX).uniform_full_int( + shape=shape, dtype=dtype)) + self.assertAllEqual(cpu, xla) def _testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. 
@@ -364,4 +400,5 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase): if __name__ == "__main__": ops.enable_eager_execution() + config.set_soft_device_placement(False) test.main() diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index f9d792806b0..23e827f18e8 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -21,7 +21,11 @@ from __future__ import print_function import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.compiler.xla import xla +from tensorflow.python.eager import def_function +from tensorflow.python.framework import config from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.kernel_tests.random import util as \ random_test_util from tensorflow.python.ops import array_ops @@ -39,6 +43,26 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): allowed_types.update({dtypes.int32, dtypes.int64}) return self.all_tf_types & allowed_types + @test_util.run_v2_only + def testForcedCompile(self): + """Tests whole-function forced-compilation. + + This test checks that stateless_random_* can be used in forced-compilation + scenarios (e.g. TPU). The new version of stateless_random_* requires the + intermediate tensor `alg` to be compile-time constant, so we need to check + that this requirement is met. We use xla.compile instead of tf.function's + experimental_compile because the latter doesn't throw an error even if the + compile-time-constant constraint is not met. + """ + if config.list_logical_devices('TPU'): + self.skipTest('To accommodate OSS, xla.compile support for TPU is not ' + 'linked in.') + @def_function.function + def f(x): + return xla.compile( + lambda x: stateless.stateless_random_normal([], seed=x), [x]) + f([1, 2]) + def testDeterminism(self): # Stateless values should be equal iff the seeds are equal (roughly) with self.session(), self.test_scope(): @@ -138,7 +162,7 @@ class StatelessRandomOpsBenchmark(test.Benchmark): def _benchmarkUniform(self, name, dtype, use_xla_jit): - def BuilderFn(): + def builder_fn(): shape = (10, 1000, 1000) seed_var = variables.Variable((312, 456), dtype=dtypes.int32, @@ -147,7 +171,7 @@ class StatelessRandomOpsBenchmark(test.Benchmark): shape, seed=seed_var, dtype=dtype) return '%s.shape%s' % (name, shape), [random_t] - xla_test.Benchmark(self, BuilderFn, use_xla_jit=use_xla_jit, device='cpu') + xla_test.Benchmark(self, builder_fn, use_xla_jit=use_xla_jit, device='cpu') def benchmarkUniformF32(self): self._benchmarkUniform( @@ -167,4 +191,5 @@ class StatelessRandomOpsBenchmark(test.Benchmark): if __name__ == '__main__': + config.set_soft_device_placement(False) test.main() diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index 7bbfecff403..4109fdc64a5 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -214,7 +214,6 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): upper, expected=np.minimum(np.maximum(x, lower), upper)) - @test_util.disable_mlir_bridge('Enable tf.Betainc Compilation') def testBetaincSanity(self): # This operation is only supported for float32 and float64. 
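testForcedCompile above exists because the new stateless_random_* lowering needs the intermediate alg tensor to be a compile-time constant, and, as its docstring notes, xla.compile reports a violated constant constraint while tf.function's experimental_compile does not. A stripped-down sketch of the same check, reusing the internal xla.compile entry point the test itself imports (not public API, so treat the import path as an assumption carried over from the test):

import tensorflow as tf
from tensorflow.python.compiler.xla import xla  # internal entry point, as used in the test above

@tf.function
def forced(seed):
    # xla.compile forces whole-computation XLA compilation, so a
    # "must be a compile-time constant" failure surfaces as an error here.
    return xla.compile(lambda s: tf.random.stateless_normal([], seed=s), [seed])

forced(tf.constant([1, 2]))

The config.set_soft_device_placement(False) lines added to the test mains serve a related purpose: with soft placement off, an op that cannot run on the XLA device fails loudly instead of silently falling back to CPU.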
for dtype in self.numeric_types & {np.float32, np.float64}: @@ -252,7 +251,6 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): 'atol': 2e-4 }, ) - @test_util.disable_mlir_bridge('Enable tf.Betainc Compilation') def testBetainc(self, sigma, rtol, atol): # This operation is only supported for float32 and float64. for dtype in self.numeric_types & {np.float32, np.float64}: diff --git a/tensorflow/compiler/tests/unary_ops_composition_test.cc b/tensorflow/compiler/tests/unary_ops_composition_test.cc index 569261de094..0e40c497c24 100644 --- a/tensorflow/compiler/tests/unary_ops_composition_test.cc +++ b/tensorflow/compiler/tests/unary_ops_composition_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/synchronization/notification.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -43,6 +44,11 @@ limitations under the License. namespace tensorflow { namespace { +static bool Initialized = [] { + tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true; + return true; +}(); + class UnaryOpsCompositionTest : public OpsTestBase { protected: template diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index eb022da6895..b5f82bcff12 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -96,7 +96,7 @@ class UnaryOpsTest(xla_test.XLATestCase): self.assertAllEqual(result, expected) @test_util.disable_mlir_bridge( - "MlirHloBuilder::Iota missing required for xla::Diag") + "Handle complex element type in DiagPart lowering") def testAllTypeOps(self): for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( @@ -538,8 +538,6 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([-40, 40], dtype=dtype), expected=np.array([1.0, 0.025], dtype=dtype)) - @test_util.disable_mlir_bridge( - "TODO(b/153812660): Handle tf.QuantizeAndDequantize compilation") def testQuantizeAndDequantize(self): for dtype in self.float_types: @@ -1070,8 +1068,6 @@ class UnaryOpsTest(xla_test.XLATestCase): ], equality_test=self.ListsAreClose) - @test_util.disable_mlir_bridge( - "TODO(b/153812660): Handle tf.DepthToSpace compilation") def testDepthToSpace(self): def make_op(data_format): @@ -1118,14 +1114,12 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( make_op("NCHW_VECT_C"), np.arange(32, dtype=dtype).reshape((1, 8, 1, 1, 4)), - expected=np.array([[[[[0, 1], [8, 9]], [[16, 17], [24, 25]]], - [[[2, 3], [10, 11]], [[18, 19], [26, 27]]], - [[[4, 5], [12, 13]], [[20, 21], [28, 29]]], - [[[6, 7], [14, 15]], [[22, 23], [30, 31]]]]], + expected=np.array([[[[[0, 1, 2, 3], [8, 9, 10, 11]], + [[16, 17, 18, 19], [24, 25, 26, 27]]], + [[[4, 5, 6, 7], [12, 13, 14, 15]], + [[20, 21, 22, 23], [28, 29, 30, 31]]]]], dtype=dtype)) - @test_util.disable_mlir_bridge( - "TODO(b/153812660): Handle tf.SpaceToDepth compilation") def testSpaceToDepth(self): def make_op(data_format): @@ -1172,11 +1166,11 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( make_op("NCHW_VECT_C"), np.arange(32, dtype=dtype).reshape((1, 2, 2, 2, 4)), - expected=np.array([[[[[0, 1, 2, 3, 16, 17, 18, 19]]], - [[[4, 5, 6, 7, 20, 21, 22, 23]]], - [[[8, 9, 10, 11, 24, 25, 26, 27]]], - [[[12, 13, 14, 15, 28, 29, 30, 31]]]]], - dtype=dtype)) + expected=np.array( + 
[[[[[0, 1, 2, 3]]], [[[16, 17, 18, 19]]], [[[4, 5, 6, 7]]], + [[[20, 21, 22, 23]]], [[[8, 9, 10, 11]]], [[[24, 25, 26, 27]]], + [[[12, 13, 14, 15]]], [[[28, 29, 30, 31]]]]], + dtype=dtype)) def _assertSoftplusMatchesExpected(self, features, diff --git a/tensorflow/compiler/tests/xla_device_gpu_test.py b/tensorflow/compiler/tests/xla_device_gpu_test.py index 1e30ebd55d0..304405c82ce 100644 --- a/tensorflow/compiler/tests/xla_device_gpu_test.py +++ b/tensorflow/compiler/tests/xla_device_gpu_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.client import session as session_lib +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -27,6 +28,10 @@ from tensorflow.python.platform import test class XlaDeviceGpuTest(test.TestCase): + def __init__(self, method_name="runTest"): + super(XlaDeviceGpuTest, self).__init__(method_name) + context.context().enable_xla_devices() + def testCopiesToAndFromGpuWork(self): """Tests that copies between GPU and XLA devices work.""" if not test.is_gpu_available(): diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 0d6ae81ef6e..3e9f5e8c5dd 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -79,6 +79,25 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): args=(v,), expected=np.tile(v, (7, 42, 1, 1))) + @test_util.disable_mlir_bridge('Not supported yet') + def testGather(self): + operand = np.arange(10, dtype=np.int32).reshape([2, 5]) + start_indices = np.array([2], np.int32) + slice_sizes = np.array([1, 3], np.int32) + + def gather(operand, start_indices): + dimension_numbers = xla_data_pb2.GatherDimensionNumbers() + dimension_numbers.offset_dims.extend([1]) + dimension_numbers.collapsed_slice_dims.extend([0]) + dimension_numbers.start_index_map.extend([0]) + dimension_numbers.index_vector_dim = 1 + return xla.gather(operand, start_indices, dimension_numbers, slice_sizes) + + self._assertOpOutputMatchesExpected( + gather, + args=(operand, start_indices), + expected=np.array([[5, 6, 7]])) + @test_util.disable_mlir_bridge('Dynamic result types not supported') def testShiftRightLogical(self): self._assertOpOutputMatchesExpected( diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 3b057ed8b17..de97c6ff210 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -83,6 +83,8 @@ class XLATestCase(test.TestCase): def __init__(self, method_name='runTest'): super(XLATestCase, self).__init__(method_name) + if 'XLA' in FLAGS.test_device: + context.context().enable_xla_devices() context.context().enable_mlir_bridge = test_util.is_mlir_bridge_enabled() self.device = FLAGS.test_device @@ -235,8 +237,8 @@ class XLATestCase(test.TestCase): 'test_session not supported on XLATestCase, please use session') @contextlib.contextmanager - def test_scope(self): - """Test scope that runs tests on `self.device`. + def device_scope(self): + """Scope that runs tests on `self.device`. Yields: A scope to apply to the operators under test. @@ -244,6 +246,15 @@ class XLATestCase(test.TestCase): with ops.device('device:{}:0'.format(self.device)): yield + def test_scope(self): + """Deprecated alias of `device_scope`. 
+ + This should be avoided as the name starts with `test`, so test runners + treat it as a test. This interferes with class decorators that operate on + each test method. + """ + return self.device_scope() + def Benchmark(tf_bench, builder_fn, diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 0718bd8cd65..44fb5513886 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -11,7 +11,6 @@ load( "tf_custom_op_library_additional_deps", "tf_gen_op_libs", "tf_gen_op_wrapper_py", - "tf_gpu_kernel_library", ) # buildifier: disable=same-origin-load @@ -81,6 +80,7 @@ tf_cuda_cc_test( cc_library( name = "common_utils", + srcs = ["common/utils.cc"], hdrs = ["common/utils.h"], copts = tf_copts(), deps = [ @@ -539,20 +539,6 @@ tf_cuda_cc_test( ], ) -tf_gpu_kernel_library( - name = "plugin_cast", - srcs = ["plugin/plugin_cast.cu.cc"], - deps = [ - ":trt_plugins", - "@com_google_absl//absl/strings", - "//tensorflow/core/platform:logging", - "//tensorflow/core:framework_lite", - ] + if_tensorrt([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_tensorrt//:tensorrt", - ]), -) - tf_cuda_library( name = "trt_plugins", srcs = ["plugin/trt_plugin.cc"], @@ -602,6 +588,7 @@ pybind_extension( link_in_framework = True, module_name = "_pywrap_py_utils", deps = [ + ":common_utils", ":py_utils", "//tensorflow/core/platform:env", "//tensorflow/core/platform:logging", diff --git a/tensorflow/compiler/tf2tensorrt/common/utils.cc b/tensorflow/compiler/tf2tensorrt/common/utils.cc new file mode 100644 index 00000000000..6679ca04513 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/common/utils.cc @@ -0,0 +1,99 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
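The device_scope rename above is motivated by test discovery: any method whose name starts with "test" is collected and invoked by the runner, and class decorators such as run_all_without_tensor_float_32 wrap every test* method, so a helper named test_scope gets caught by both. A self-contained illustration with plain unittest:

import unittest

class Demo(unittest.TestCase):

    def test_scope(self):      # starts with "test": the default loader collects it
        print("ran as a test, although it was meant to be a helper")

    def device_scope(self):    # does not match the test* prefix: never collected
        print("only runs when called explicitly")

if __name__ == "__main__":
    unittest.main(verbosity=2)  # executes test_scope, skips device_scope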
+==============================================================================*/ + +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "absl/base/call_once.h" +#include "absl/strings/str_join.h" +#include "third_party/tensorrt/NvInferPlugin.h" +#endif + +namespace tensorflow { +namespace tensorrt { + +std::tuple GetLinkedTensorRTVersion() { +#if GOOGLE_CUDA && GOOGLE_TENSORRT + return std::tuple{NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, + NV_TENSORRT_PATCH}; +#else + return std::tuple{0, 0, 0}; +#endif +} + +std::tuple GetLoadedTensorRTVersion() { +#if GOOGLE_CUDA && GOOGLE_TENSORRT + int ver = getInferLibVersion(); + int major = ver / 1000; + ver = ver - major * 1000; + int minor = ver / 100; + int patch = ver - minor * 100; + return std::tuple{major, minor, patch}; +#else + return std::tuple{0, 0, 0}; +#endif +} + +} // namespace tensorrt +} // namespace tensorflow + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +namespace tensorflow { +namespace tensorrt { +namespace { + +void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) { + LOG(INFO) << "Linked TensorRT version: " + << absl::StrJoin(GetLinkedTensorRTVersion(), "."); + LOG(INFO) << "Loaded TensorRT version: " + << absl::StrJoin(GetLoadedTensorRTVersion(), "."); + + bool plugin_initialized = initLibNvInferPlugins(trt_logger, ""); + if (!plugin_initialized) { + LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may " + "fail later."; + } + + int num_trt_plugins = 0; + nvinfer1::IPluginCreator* const* trt_plugin_creator_list = + getPluginRegistry()->getPluginCreatorList(&num_trt_plugins); + if (!trt_plugin_creator_list) { + LOG_WARNING_WITH_PREFIX << "Can not find any TensorRT plugins in registry."; + } else { + VLOG(1) << "Found the following " << num_trt_plugins + << " TensorRT plugins in registry:"; + for (int i = 0; i < num_trt_plugins; ++i) { + if (!trt_plugin_creator_list[i]) { + LOG_WARNING_WITH_PREFIX + << "TensorRT plugin at index " << i + << " is not accessible (null pointer returned by " + "getPluginCreatorList for this plugin)"; + } else { + VLOG(1) << " " << trt_plugin_creator_list[i]->getPluginName(); + } + } + } +} + +} // namespace + +void MaybeInitializeTrtPlugins(nvinfer1::ILogger* trt_logger) { + static absl::once_flag once; + absl::call_once(once, InitializeTrtPlugins, trt_logger); +} + +} // namespace tensorrt +} // namespace tensorflow +#endif diff --git a/tensorflow/compiler/tf2tensorrt/common/utils.h b/tensorflow/compiler/tf2tensorrt/common/utils.h index b428733ecd4..b76b75de783 100644 --- a/tensorflow/compiler/tf2tensorrt/common/utils.h +++ b/tensorflow/compiler/tf2tensorrt/common/utils.h @@ -16,15 +16,33 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_ #define TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_ +#include + +namespace tensorflow { +namespace tensorrt { +// Returns the compile time TensorRT library version information +// {Maj, Min, Patch}. +std::tuple GetLinkedTensorRTVersion(); + +// Returns the runtime time TensorRT library version information +// {Maj, Min, Patch}. +std::tuple GetLoadedTensorRTVersion(); +} // namespace tensorrt +} // namespace tensorflow + #if GOOGLE_CUDA && GOOGLE_TENSORRT #include "tensorflow/core/platform/logging.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace tensorrt { #define LOG_WARNING_WITH_PREFIX LOG(WARNING) << "TF-TRT Warning: " +// Initializes the TensorRT plugin registry if this hasn't been done yet. 
+void MaybeInitializeTrtPlugins(nvinfer1::ILogger* trt_logger); + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index c4fc3e4f5da..2804a381e0c 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -733,6 +733,8 @@ Status ConvertAfterShapes(const ConversionParams& params) { } segment_options.minimum_segment_size = params.minimum_segment_size; segment_options.use_implicit_batch = params.use_implicit_batch; + if (segment_options.use_implicit_batch) + segment_options.maximum_batch_size = params.max_batch_size; segment_options.allow_dynamic_non_batch_dim = AllowDynamicNonBatchDimension(params); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index 3b0553426c0..be3bb51dbed 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -151,7 +151,8 @@ TEST(ConvertGraphTest, GetDeviceAndAllocator) { class ConvertAfterShapesTest : public ::testing::Test { public: - Status RunConvertAfterShape(Scope s, GraphDef* output_graph_def) { + Status RunConvertAfterShape(Scope s, GraphDef* output_graph_def, + int maximum_batch_size = 1000) { // Create GraphProperties. grappler::GrapplerItem item; TF_EXPECT_OK(s.ToGraphDef(&item.graph)); @@ -162,6 +163,7 @@ class ConvertAfterShapesTest : public ::testing::Test { const std::vector output_names{"output"}; ConversionParams params; params.output_names = &output_names; + params.max_batch_size = maximum_batch_size; params.max_workspace_size_bytes = 8 << 20; params.output_graph_def = output_graph_def; params.minimum_segment_size = 1; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index dc5acbb4f50..c0c3f25177e 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -1197,42 +1197,6 @@ Status TrtNodeValidator::ConvertConstToWeights( return status; } -static void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) { - static mutex plugin_mutex(LINKER_INITIALIZED); - static bool plugin_initialized = false; - mutex_lock lock(plugin_mutex); - if (plugin_initialized) return; - - LOG(INFO) << "Linked TensorRT version: " << GetLinkedTensorRTVersion(); - LOG(INFO) << "Loaded TensorRT version: " << GetLoadedTensorRTVersion(); - - plugin_initialized = initLibNvInferPlugins(trt_logger, ""); - if (!plugin_initialized) { - LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may " - "fail later."; - } - - int num_trt_plugins = 0; - nvinfer1::IPluginCreator* const* trt_plugin_creator_list = - getPluginRegistry()->getPluginCreatorList(&num_trt_plugins); - if (!trt_plugin_creator_list) { - LOG_WARNING_WITH_PREFIX << "Can not find any TensorRT plugins in registry."; - } else { - VLOG(1) << "Found the following " << num_trt_plugins - << " TensorRT plugins in registry:"; - for (int i = 0; i < num_trt_plugins; ++i) { - if (!trt_plugin_creator_list[i]) { - LOG_WARNING_WITH_PREFIX - << "TensorRT plugin at index " << i - << " is not accessible (null pointer returned by " - "getPluginCreatorList for this plugin)"; - } else { - VLOG(1) << " " << trt_plugin_creator_list[i]->getPluginName(); - } - } - } -} - // static StatusOr> 
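The new common/utils.{h,cc} pair above centralizes two things that convert_nodes.cc used to own: the TensorRT version helpers, which now return a {major, minor, patch} tuple instead of a pre-formatted string (call sites switch to absl::StrJoin(..., ".")), and plugin registration, which is now guarded by absl::call_once instead of a hand-rolled mutex-plus-bool. The loaded version comes from getInferLibVersion(), which packs the three numbers as major*1000 + minor*100 + patch; a small sketch of that decode, written in Python only to mirror the arithmetic of the C++ helper:

def decode_trt_version(packed):
    """Splits getInferLibVersion()'s packed integer into (major, minor, patch)."""
    major, rest = divmod(packed, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch

assert decode_trt_version(7103) == (7, 1, 3)   # e.g. a TensorRT 7.1.3 runtime
assert decode_trt_version(6015) == (6, 0, 15)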
Converter::Create( TrtPrecisionMode precision_mode, bool use_calibration, @@ -1249,7 +1213,7 @@ Converter::Converter(TrtPrecisionMode precision_mode, bool use_calibration, : precision_mode_(precision_mode), use_calibration_(use_calibration), use_implicit_batch_(use_implicit_batch) { - InitializeTrtPlugins(trt_logger); + MaybeInitializeTrtPlugins(trt_logger); this->RegisterOpConverters(); } @@ -1434,7 +1398,8 @@ Status Converter::BuildCudaEngine( TF_RETURN_IF_ERROR( TrtPrecisionModeToName(precision_mode_, &precision_mode_str)); string trt_network_name = StrCat( - "TF:", TF_VERSION_STRING, ", ", "TRT:", GetLoadedTensorRTVersion(), "-", + "TF:", TF_VERSION_STRING, ", ", + "TRT:", absl::StrJoin(GetLoadedTensorRTVersion(), "."), "-", "Precision:", precision_mode_str, ", ", "Calibration:", use_calibration_, ", ", "Max-Batch-Size:", max_batch_size, ", ", "Max-Workspace-Size:", max_workspace_size_bytes); @@ -2410,6 +2375,40 @@ Status ConvertTranspose(OpConverterParams* params) { return Status::OK(); } +Status ConvertShape(OpConverterParams* params) { + const auto& inputs = params->inputs; + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", TrtInputArg::kBoth}})); + if (params->use_implicit_batch) { + return errors::Unimplemented( + "Shape is only supported for explicit batch mode."); + } + if (HasStaticShape(inputs.at(0).GetTrtDims())) { + if (params->validation_only) return Status::OK(); + nvinfer1::Dims input_dims = inputs.at(0).GetTrtDims(); + nvinfer1::Dims output_dims{1, {input_dims.nbDims}}; + // Create a const node with the values of output_dims + TRT_ShapedWeights weight = params->weight_store->GetTempWeights( + nvinfer1::DataType::kINT32, output_dims); + int32* values_ptr = static_cast(weight.GetValues()); + std::copy(input_dims.d, input_dims.d + input_dims.nbDims, values_ptr); + auto output = params->converter->CreateConstantLayer(weight, output_dims); + params->outputs->push_back(TRT_TensorOrWeights(output)); + return Status::OK(); + } +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + if (params->validation_only) return Status::OK(); + nvinfer1::IShapeLayer* shape_layer = + params->converter->network()->addShape(*inputs.at(0).tensor()); + TFTRT_RETURN_ERROR_IF_NULLPTR(shape_layer, params->node_def.name()); + params->outputs->push_back(TRT_TensorOrWeights(shape_layer->getOutput(0))); + return Status::OK(); +#else + return errors::Unavailable( + "Shape op conversion requires TensorRT 6 or above"); +#endif +} + Status ConvertReshape(OpConverterParams* params) { const auto& inputs = params->inputs; TF_RETURN_IF_ERROR( @@ -3749,6 +3748,7 @@ Status ConvertActivation(OpConverterParams* params) { params->converter->network()->addActivation(*inputs.at(0).tensor(), op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setName(node_def.name().c_str()); // Set parameters. 
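ConvertShape above gives TF-TRT a lowering for the Shape op: with a statically known input shape it materializes the dimensions as an INT32 constant layer, otherwise (TensorRT 6 and later) it emits an IShapeLayer, and it is rejected outright in implicit batch mode. At the TensorFlow level the op being covered is simply tf.shape, which yields a rank-1 int32 tensor:

import tensorflow as tf

@tf.function
def shape_times_two(x):
    # tf.shape produces a 1-D int32 tensor holding the runtime dimensions;
    # with the converter above, this node can stay inside a TF-TRT segment
    # when the engine is built in explicit batch (dynamic shape) mode.
    return tf.shape(x) * 2

print(shape_times_two(tf.zeros([2, 3, 5])))  # -> [ 4  6 10]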
#if IS_TRT_VERSION_GE(5, 1, 2, 0) if (node_def.op() == "Elu") { @@ -3849,9 +3849,10 @@ Status ConvertRelu6(OpConverterParams* params) { nvinfer1::IActivationLayer* layer = params->converter->network()->addActivation( *inputs.at(0).tensor(), nvinfer1::ActivationType::kCLIP); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); layer->setAlpha(0.0f); layer->setBeta(6.0f); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setName(node_def.name().c_str()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -4407,6 +4408,7 @@ Status ConvertUnary(OpConverterParams* params) { nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(*tensor, op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setName(node_def.name().c_str()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); // Set quantization ranges. @@ -4484,7 +4486,7 @@ Status ConvertReduce(OpConverterParams* params) { int trt_axis; TF_RETURN_IF_ERROR( ConvertAxis(tf_axes_list[i], tensor->getDimensions().nbDims, - node_def.name(), /*use_implicit_batch=*/true, &trt_axis)); + node_def.name(), params->use_implicit_batch, &trt_axis)); axes |= (1 << trt_axis); } @@ -5055,6 +5057,7 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { combined_scale_weights.GetTrtWeights(), dummy_power_weights.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setName(node_def.name().c_str()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -5974,6 +5977,7 @@ static void RegisterValidatableOpConverters( (*registration)[pool_op_type] = ConvertPool3D; } #endif + (*registration)["Shape"] = ConvertShape; (*registration)["Rsqrt"] = ConvertRsqrt; (*registration)["Slice"] = ConvertSlice; (*registration)["Softmax"] = ConvertSoftmax; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 29eb24d2316..b127337e02a 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -1709,12 +1709,12 @@ class ParameterizedOpConverterTestBase std::tuple> { public: ParameterizedOpConverterTestBase() - : trt_mode(std::get<0>(GetParam())), - tf_type(std::get<1>(GetParam())), - converter_precision(std::get<2>(GetParam())) {} + : trt_mode_(std::get<0>(GetParam())), + tf_type_(std::get<1>(GetParam())), + converter_precision_(std::get<2>(GetParam())) {} void Reset() { - OpConverterTest::Reset(converter_precision, trt_mode); + OpConverterTest::Reset(converter_precision_, trt_mode_); input_data_.clear(); } @@ -1750,7 +1750,7 @@ class ParameterizedOpConverterTestBase if (!partial_input_shape_dims.empty()) { partial_shape = partial_input_shape_dims; } else { - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // In dynamic shape mode we make all dims unknown. 
partial_shape = std::vector(dims.size(), -1); } else { @@ -1776,7 +1776,7 @@ class ParameterizedOpConverterTestBase void AddTestTensor(const string& name, const std::vector& dims, const std::vector& values = {}, const std::vector& partial_input_shape_dims = {}) { - AddTestTensor(name, dims, tf_type, values, partial_input_shape_dims); + AddTestTensor(name, dims, tf_type_, values, partial_input_shape_dims); } // Builds and runs the converted network. Checks output tensor shape. Tests @@ -1785,7 +1785,8 @@ class ParameterizedOpConverterTestBase void BuildAndRun(const string& name, const std::vector>& expected_output_dims, const Status& expected_runtime_status, - const std::vector>>& matcher) { + const std::vector>>& matcher, + const std::vector& out_tf_types = {}) { TensorShape shape; const int n_output = expected_output_dims.size(); ASSERT_EQ(n_output, matcher.size()); @@ -1794,12 +1795,14 @@ class ParameterizedOpConverterTestBase TF_EXPECT_OK( TensorShapeUtils::MakeShape(expected_output_dims[i], &shape)); string out_name = (n_output == 1) ? name : StrCat(name, ":", i); - InputOutputData data{out_name, - ConstructTensor(shape.num_elements(), 0, tf_type)}; + DataType out_tf_type = + out_tf_types.size() > i ? out_tf_types[i] : tf_type_; + InputOutputData data{ + out_name, ConstructTensor(shape.num_elements(), 0, out_tf_type)}; output_data.push_back(data); } - ASSERT_FALSE(input_data_.empty()); - const int batch_size = input_data_[0].tensor.shape().dim_size(0); + const int batch_size = + input_data_.empty() ? 1 : input_data_[0].tensor.shape().dim_size(0); Status stat = OpConverterTest::BuildAndRun(input_data_, &output_data, batch_size); ASSERT_EQ(expected_runtime_status.ok(), stat.ok()) @@ -1824,20 +1827,22 @@ class ParameterizedOpConverterTestBase const std::vector& expected_output_dims, const Status& expected_conversion_status, const Status& expected_runtime_status, - const Matcher>& matcher) { + const Matcher>& matcher, + const std::vector& out_tf_types = {}) { RunValidationAndConversion(node_def, expected_conversion_status, name.c_str(), expected_output_dims); if (expected_conversion_status.ok()) { BuildAndRun(name, std::vector>({expected_output_dims}), expected_runtime_status, - std::vector>>({matcher})); + std::vector>>({matcher}), + out_tf_types); } } protected: - const TrtTestMode trt_mode; - const DataType tf_type; - const TrtPrecisionMode converter_precision; + const TrtTestMode trt_mode_; + const DataType tf_type_; + const TrtPrecisionMode converter_precision_; DataVec input_data_; }; @@ -2070,7 +2075,7 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { 37.342354, 41.013527, 30.9738, 34.469433, 45.018955, 48.59309, 59.369415, 63.04059}; for (auto get_node_def : get_node_def_vec) { - NodeDef tmp_node_def = get_node_def(tf_type, "NCHW", true, 0); + NodeDef tmp_node_def = get_node_def(tf_type_, "NCHW", true, 0); std::string op_name = tmp_node_def.op(); std::vector test_param{ {"NHWC", 0, false, 0, @@ -2092,7 +2097,7 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { errors::Unimplemented(StrCat("The input \"variance\" for ", op_name, " must be a constant, at my_batchnorm"))}, {"NCHW", 0, false, 0.01}}; // The last one is the only test that runs. 
- if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { test_param.push_back( {"NCHW", 0, false, 0.01, errors::InvalidArgument( @@ -2102,7 +2107,7 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { for (auto p : test_param) { Reset(); NodeDef node_def = - get_node_def(tf_type, p.data_format, p.is_training, p.epsilon); + get_node_def(tf_type_, p.data_format, p.is_training, p.epsilon); for (int i = 0; i < node_input.size(); i++) { if (i == 0 || i == p.tensor_input_idx) { // The first input (x) is always added as a tensor, and it hase shape @@ -2121,7 +2126,7 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { // the first arg is a tensor. TODO(tfeher) Check if one can relax this // restriction. Status expected_status = - (i != 0 && trt_mode == TrtTestMode::kImplicitBatch) + (i != 0 && trt_mode_ == TrtTestMode::kImplicitBatch) ? errors::InvalidArgument( StrCat("Batch size doesn't match for tensor ", node_input[i].name, @@ -2129,19 +2134,19 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { "converter batch size: 3 vs 2")) : Status::OK(); std::vector partial_input_shape; - if (i == 0 && trt_mode == TrtTestMode::kDynamicShape && + if (i == 0 && trt_mode_ == TrtTestMode::kDynamicShape && !p.keep_channel_unknown) { // keep channel dim static (known) partial_input_shape.resize(4, -1); partial_input_shape[1] = node_input[i].dims[1]; } - AddTestTensor(node_input[i].name, node_input[i].dims, tf_type, + AddTestTensor(node_input[i].name, node_input[i].dims, tf_type_, node_input[i].val, partial_input_shape, expected_status); } else { AddTestWeights(node_input[i].name, node_input[i].dims, - node_input[i].val, tf_type); + node_input[i].val, tf_type_); } } TestOpConverter("my_batchnorm", node_def, node_input[0].dims, @@ -2149,12 +2154,12 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { ArrayFloatNear(expected_output)); } } -} // namespace convert +} TEST_P(OpConverterTest1, ConvertTranspose) { // Get the NodeDef for Transpose. Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); auto transpose = ops::Transpose(s.WithOpName("my_transpose"), input, weights); const NodeDef& node_def = transpose.operation.node()->def(); @@ -2182,13 +2187,13 @@ TEST_P(OpConverterTest1, ConvertTranspose) { {}, {3, 2, 1, 1}, {3, 2, 1, 0}, - (trt_mode == TrtTestMode::kImplicitBatch) + (trt_mode_ == TrtTestMode::kImplicitBatch) ? Status(error::UNIMPLEMENTED, "Transpose at batch dimension is not supported") : Status::OK()}, TestParamBase{{1, 1, 2, 3}, {}, {1, 3, 1, 2}, {0, 3, 1, 2}}, }; - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // Dynamic shape tests where some shapes are known test_params.push_back(TestParamBase{ {1, 1, 2, 3}, {-1, 1, 2, -1}, {1, 3, 1, 2}, {0, 3, 1, 2}}); @@ -2309,6 +2314,55 @@ TEST_F(OpConverterTest, ConvertReshape) { } } +TEST_P(OpConverterTest1, ConvertShape) { + // Get the NodeDef for Shape op. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); + auto shape = ops::Shape(s.WithOpName("my_shape"), input); + const NodeDef& node_def = shape.operation.node()->def(); + + Status conversion_status = + (trt_mode_ == TrtTestMode::kImplicitBatch) + ? 
errors::Unimplemented( + "Shape is only supported for explicit batch mode.") + : Status::OK(); + std::vector test_params = { +// TODO(b/166274212): Enable the test parameter for TensorRT 7.1.3. +#if !IS_TRT_VERSION_GE(7, 1, 3, 0) + TestParamBase{{1, 2, 3}, {}, {3}, {}, conversion_status}, +#endif + // Add input as weight (we use non empty param ({1}) to trigger this). + TestParamBase{{1, 2, 3}, {}, {3}, {1}, conversion_status}, + }; + + auto input_is_weight = [](const TestParamBase p) { return !p.param.empty(); }; + for (auto p : test_params) { + SCOPED_TRACE(p); + Reset(); + // The number of elements of the input tensor. We leave it 0 in case we do + // not need to add an input tensor. This happens in explicit batch mode: the + // shape is known at conversion time and therefore the shape is added to the + // network as a constant layer. In this case the single node network that + // we use for the unit test have no actual input tensor when it is converted + // to a TensorRT network. + int n_elements = 0; + if (input_is_weight(p) || trt_mode_ != TrtTestMode::kExplicitBatch) { + // Calculate the number of elements for adding input data. + n_elements = std::accumulate(p.input_dims.begin(), p.input_dims.end(), 1, + std::multiplies()); + } + std::vector input_val(n_elements, 1); + if (!input_is_weight(p)) { + AddTestTensor("input", p.input_dims, input_val); + } else { + AddTestWeights("input", p.input_dims, input_val, tf_type_); + } + TestOpConverter("my_shape", node_def, p.expected_output_dims, p.status, + p.runtime_status, ElementsAreArray(p.input_dims), + {DT_INT32}); + } +} + // Helper function for testing MatMul and BatchMatMul // get_matmul corresponds to the function used to generate the node. It should // accept (DataType, transpose_a, transpose_b) as parameters. @@ -2566,7 +2620,7 @@ TEST_P(OpConverterTest2, ConvertBiasAdd) { for (const string& data_format : {"NHWC", "NCHW"}) { for (const int trt_input_rank : {1, 2, 3, 4}) { Reset(); - NodeDef node_def = get_biasadd_nodedef(data_format, tf_type); + NodeDef node_def = get_biasadd_nodedef(data_format, tf_type_); // Add input, dims_array will be like {2, 1, ..., 1, 3} std::vector dims_array(trt_input_rank + 1, 1); @@ -2588,7 +2642,7 @@ TEST_P(OpConverterTest2, ConvertBiasAdd) { for (int i = 0; i < channel_size; ++i) { bias[i] = i + 1; // bias will be {1, 2, 3, ...} } - AddTestWeights("weights", {channel_size}, bias, tf_type); + AddTestWeights("weights", {channel_size}, bias, tf_type_); // Build and run the engine. std::vector output_data; @@ -2624,7 +2678,7 @@ NodeDef GetBinaryOpNodeDef(DataType dtype) { TEST_P(OpConverterTest2, ConvertBinary) { { AttrValue dtype; - dtype.set_type(tf_type); + dtype.set_type(tf_type_); // Both inputs are weights. 
Reset(); NodeDef node_def = @@ -2669,19 +2723,19 @@ TEST_P(OpConverterTest2, ConvertBinary) { if (!op_test_info.count(op_name)) { FAIL() << "Binary op test map does not contain op " << op_name; } - NodeDef node_def = op_test_info[op_name].first(tf_type); + NodeDef node_def = op_test_info[op_name].first(tf_type_); std::vector input_names; std::vector> input_dims; std::vector> input_values; if (operand_1_is_tensor) { AddTestTensor("input1", {2, 1, 2}, {3, 6, 3, 6}); } else { - AddTestWeights("input1", {1, 2}, std::vector{3, 6}, tf_type); + AddTestWeights("input1", {1, 2}, std::vector{3, 6}, tf_type_); } if (operand_2_is_tensor) { AddTestTensor("input2", {2, 2, 1}, {2, 3, 2, 3}); } else { - AddTestWeights("input2", {2, 1}, std::vector{2, 3}, tf_type); + AddTestWeights("input2", {2, 1}, std::vector{2, 3}, tf_type_); } TestOpConverter("my_binary", node_def, {2, 2, 2}, Status::OK(), Status::OK(), @@ -2888,10 +2942,10 @@ TEST_P(OpConverterTest2, ConvertSquare) { // Input is weights, should fail. Reset(); Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); auto square = ops::Square(s.WithOpName("my_square"), input); NodeDef node_def = square.operation.node()->def(); - AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}, tf_type); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}, tf_type_); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "The input \"x\" for Square must be a tensor, at my_square"); @@ -2900,7 +2954,7 @@ TEST_P(OpConverterTest2, ConvertSquare) { Reset(); Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); auto square = ops::Square(s.WithOpName("my_square"), input); NodeDef node_def = square.operation.node()->def(); @@ -2913,7 +2967,7 @@ TEST_P(OpConverterTest2, ConvertSquare) { inputs[i] = value; expected_outputs[i] = value * value; } - AddTestTensor("input", {1, 1, 20}, tf_type, inputs); + AddTestTensor("input", {1, 1, 20}, tf_type_, inputs); TestOpConverter("my_square", node_def, {1, 1, 20}, Status::OK(), Status::OK(), ArrayFloatNear(expected_outputs, 0)); @@ -3040,7 +3094,7 @@ TEST_P(OpConverterTest1, ConvertActivation) { { // Input is weights, should fail. Reset(); - const NodeDef& node_def = CreateUnaryOp(tf_type); + const NodeDef& node_def = CreateUnaryOp(tf_type_); AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, @@ -3097,7 +3151,7 @@ TEST_P(OpConverterTest1, ConvertActivation) { FAIL() << "Activation op test map does not contain op " << op_name; } Reset(); - NodeDef node_def = op_map[op_name].first(tf_type); + NodeDef node_def = op_map[op_name].first(tf_type_); const std::vector input = {-100, -2, -1, 0, 1, 88}; AddTestTensor("input", p.input_dims, input); @@ -3125,7 +3179,7 @@ TEST_P(OpConverterTest1, ConvertActivation) { TEST_P(OpConverterTest1, ConvertExpandDims) { // Get the NodeDef for ExpandDims. 
Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); auto expanddims = ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights); @@ -3153,7 +3207,7 @@ TEST_P(OpConverterTest1, ConvertExpandDims) { {}, {1, 1, 1, 2, 3}, {0}, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status(error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the " "batch dimension, at my_expanddims") @@ -3162,7 +3216,7 @@ TEST_P(OpConverterTest1, ConvertExpandDims) { {}, {1, 1, 1, 2, 3}, {-5}, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status(error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the " "batch dimension, at my_expanddims") @@ -3200,7 +3254,7 @@ TEST_P(OpConverterTest1, ConvertExpandDims) { } TEST_P(OpConverterTest1, ConvertSqueeze) { - const bool use_implicit_batch = (trt_mode == TrtTestMode::kImplicitBatch); + const bool use_implicit_batch = (trt_mode_ == TrtTestMode::kImplicitBatch); // Get the NodeDef for Squeeze. auto get_squeeze_nodedef = [](std::vector axes, DataType tf_type) -> NodeDef { @@ -3223,7 +3277,7 @@ TEST_P(OpConverterTest1, ConvertSqueeze) { {}, // input partial dims {2, 3}, // expected output dims {}, // axis - trt_mode == TrtTestMode::kExplicitBatch + trt_mode_ == TrtTestMode::kExplicitBatch ? Status::OK() : Status{error::UNIMPLEMENTED, "Squeeze is not implemented for empty squeeze_dims, at " @@ -3282,7 +3336,7 @@ TEST_P(OpConverterTest1, ConvertSqueeze) { "Dimension 2 with size 2 cannot be squeezed because it must be " "size 1, at my_squeeze"}}; - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // In this test we try to squeeze axis=2 which has size > 1. In dynamic // shape mode the converter sees only -1, so it cannot catch this error. squeeze_non_singleton.status = Status::OK(); // conversion status @@ -3297,7 +3351,7 @@ TEST_P(OpConverterTest1, ConvertSqueeze) { for (TestParamBase p : test_params) { SCOPED_TRACE(p); Reset(); - NodeDef node_def = get_squeeze_nodedef(p.param, tf_type); + NodeDef node_def = get_squeeze_nodedef(p.param, tf_type_); AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6}, p.partial_input_dims); TestOpConverter("my_squeeze", node_def, p.expected_output_dims, p.status, @@ -4052,14 +4106,14 @@ TEST_F(OpConverterTest, ConvertSlice) { TEST_P(OpConverterTest1, ConvertConv2D) { // Get nodedef for Conv2D layer. 
- DataType tf_type_loc = tf_type; + DataType tf_type = tf_type_; auto get_conv2d_nodedef = - [tf_type_loc](std::vector strides = {1, 1, 1, 1}, - string padding = "SAME", string data_format = "NCHW", - std::vector dilations = {1, 1, 1, 1}) -> NodeDef { + [tf_type](std::vector strides = {1, 1, 1, 1}, + string padding = "SAME", string data_format = "NCHW", + std::vector dilations = {1, 1, 1, 1}) -> NodeDef { Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type_loc); - auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type_loc); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type); ops::Conv2D::Attrs attrs = ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations); auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides, @@ -4152,12 +4206,12 @@ TEST_P(OpConverterTest1, ConvertConv2D) { node_def, error::UNIMPLEMENTED, "Stride must be 1 for batch and channel dimensions, at my_conv2d"); } - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { Reset(); NodeDef node_def = get_conv2d_nodedef(); // Channel dim unknown, should fail. AddTestTensorWithTFDims("input", {-1, -1, -1, -1}, - TfDataTypeToTrt(tf_type)); + TfDataTypeToTrt(tf_type_)); AddTestWeights("weights", {1, 2, 1, 1}, {-1, 1}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, @@ -4179,8 +4233,6 @@ TEST_P(OpConverterTest1, ConvertConv2D) { // Ok. std::vector ok_params = { -// TODO(b/162447069): Enable the test parameters for TRT 7.1.3.x. -#if !IS_TRT_VERSION_GE(7, 1, 3, 0) // Basic TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, @@ -4192,9 +4244,6 @@ TEST_P(OpConverterTest1, ConvertConv2D) { /*dilations=*/{1, 1, 1, 1}, /*expected_output_dims=*/{1, 1, 2, 2}, /*expected_output=*/{1, 1, 0, 1}}, -#endif -// TODO(b/162448349): Enable the test parameters for TRT 7.1.3.x. -#if !IS_TRT_VERSION_GE(7, 1, 3, 0) // SAME padding (Asymmetric) TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, @@ -4217,9 +4266,6 @@ TEST_P(OpConverterTest1, ConvertConv2D) { /*dilations=*/{1, 1, 1, 1}, /*expected_output_dims=*/{1, 1, 2, 3}, /*expected_output=*/{1, 2, -1, 3, 1, -3}}, -#endif -// TODO(b/162447069): Enable the test parameters for TRT 7.1.3.x. -#if !IS_TRT_VERSION_GE(7, 1, 3, 0) // NHWC TestParams{/*input_dims=*/{1, 2, 3, 1}, /*input=*/{0, 1, 2, 3, 3, 4}, @@ -4253,7 +4299,6 @@ TEST_P(OpConverterTest1, ConvertConv2D) { /*dilations=*/{1, 1, 1, 1}, /*expected_output_dims=*/{1, 1, 2, 2}, /*expected_output=*/{1, 0, 1, 3}}, -#endif }; for (int i = 0; i < ok_params.size(); i++) { @@ -4262,15 +4307,15 @@ TEST_P(OpConverterTest1, ConvertConv2D) { get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format, ok_params[i].dilations); std::vector partial_input_shape; - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // The channel dim cannot have unknown size, fix that. partial_input_shape.resize(ok_params[i].input_dims.size(), -1); int channel_id = (ok_params[i].data_format == "NCHW") ? 
1 : 3; partial_input_shape[channel_id] = ok_params[i].input_dims[channel_id]; } - AddTestTensor("input", ok_params[i].input_dims, tf_type, ok_params[i].input, - partial_input_shape); + AddTestTensor("input", ok_params[i].input_dims, tf_type_, + ok_params[i].input, partial_input_shape); AddTestWeights("weights", ok_params[i].filter_dims, ok_params[i].filter); @@ -4797,7 +4842,7 @@ TEST_P(OpConverterTest1, ConvertPool) { for (int nDim : test_nDims) { // Input is weights, should fail. Reset(); - NodeDef node_def = get_pool_nodedef(tf_type, nDim); + NodeDef node_def = get_pool_nodedef(tf_type_, nDim); AddTestWeights("input", {1, 1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, @@ -4906,7 +4951,7 @@ TEST_P(OpConverterTest1, ConvertPool) { for (bool is_max_pooling : {true, false}) { Reset(); NodeDef node_def = - get_pool_nodedef(tf_type, nDim, ksize, strides, p.padding, + get_pool_nodedef(tf_type_, nDim, ksize, strides, p.padding, data_format, is_max_pooling); AddTestTensor("input", input_dims, input); TestOpConverter("my_pool", node_def, expected_output_dims, Status::OK(), @@ -4968,7 +5013,7 @@ TEST_F(OpConverterTest, ConvertTopK) { TEST_P(OpConverterTest3, ConvertGather) { // Get the NodeDef for GatherV2. Scope s = Scope::NewRootScope(); - auto params = ops::Placeholder(s.WithOpName("params"), tf_type); + auto params = ops::Placeholder(s.WithOpName("params"), tf_type_); auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis); @@ -4976,7 +5021,7 @@ TEST_P(OpConverterTest3, ConvertGather) { { // Axis is a tensor, should fail. Reset(); - AddTestTensor("params", {1, 1, 2, 3}, tf_type, {}); + AddTestTensor("params", {1, 1, 2, 3}, tf_type_, {}); AddTestTensor("indices", {1, 2}, DT_INT32, {}); AddTestTensor("axis", {1}, DT_INT32, {}); RunValidationAndConversion( @@ -5021,7 +5066,7 @@ TEST_P(OpConverterTest3, ConvertGather) { /*expected_output_shape=*/{2, 1, 1, 3}, /*expected_output=*/{4, 5, 6, 1, 2, 3}, /*params_is_tensor=*/true, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status{error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the" " batch dimension, at my_gather"} @@ -5034,7 +5079,7 @@ TEST_P(OpConverterTest3, ConvertGather) { /*expected_output_shape=*/{2, 1, 2, 1}, /*expected_output=*/{3, 1, 6, 4}, /*params_is_tensor=*/true, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status{error::UNIMPLEMENTED, "Indices must have a batch size of 1 when params" " is a tensor."} @@ -5048,7 +5093,7 @@ TEST_P(OpConverterTest3, ConvertGather) { /*expected_output_shape=*/{2, 1, 2}, /*expected_output=*/{2, 3, 5, 6}, /*params_is_tensor=*/false, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status{error::UNIMPLEMENTED, "The input axis must be zero when params is a" " weight."} @@ -5061,13 +5106,13 @@ TEST_P(OpConverterTest3, ConvertGather) { /*expected_output_shape=*/{2}, /*expected_output=*/{2, 4}, /*params_is_tensor=*/true, - trt_mode == TrtTestMode::kImplicitBatch // conversion_status + trt_mode_ == TrtTestMode::kImplicitBatch // conversion_status ? 
Status{error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the " "batch dimension, at my_gather"} : Status::OK(), - Status::OK(), // runtime_status - trt_mode == TrtTestMode::kImplicitBatch // add_index_status + Status::OK(), // runtime_status + trt_mode_ == TrtTestMode::kImplicitBatch // add_index_status ? Status{error::INVALID_ARGUMENT, "Batch size doesn't match for tensor indices: " "Provided batch size does not match converter " @@ -5182,7 +5227,7 @@ TEST_P(OpConverterTest3, ConvertGather) { if (p.params_is_tensor) { AddTestTensor("params", p.params_shape, params_input); } else { - AddTestWeights("params", p.params_shape, params_input, tf_type); + AddTestWeights("params", p.params_shape, params_input, tf_type_); } AddTestTensor("indices", p.indices_shape, DT_INT32, p.indices, {}, p.add_index_status); @@ -5192,6 +5237,150 @@ TEST_P(OpConverterTest3, ConvertGather) { } } +template +NodeDef CreateReduceOp(DataType tf_type, bool keep_dims) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); + typename OpType::Attrs op_attrs; + op_attrs.keep_dims_ = keep_dims; + auto op = OpType(s.WithOpName("my_reduce"), input, axis, op_attrs); + return op.operation.node()->def(); +} + +// Applies reduction op on sub-sequences of input +// output[i] = reduce(input[m * i : m * (i +1)]) +std::vector CalcReduce(string op_name, std::vector input, int m, + float (*op)(float, float), float init) { + std::vector output(input.size() / m); + for (int i = 0; i < output.size(); i++) { + auto begin = input.begin() + i * m; + auto end = input.begin() + (i + 1) * m; + output[i] = std::accumulate(begin, end, init, op); + if (op_name == "Mean") { + output[i] /= m; + } + } + return output; +} +TEST_P(OpConverterTest1, ConvertReduce) { + { + // Input is weights, should fail. + Reset(); + const NodeDef node_def = CreateReduceOp(tf_type_, false); + AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); + AddTestWeights("axis", {1}, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"input\" for Sum must be a tensor, at my_reduce"); + } + { + // Axis is weights, should fail. + Reset(); + const NodeDef node_def = CreateReduceOp(tf_type_, false); + AddTestTensor("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); + AddTestTensor("axis", {1}, DT_INT32, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"axis\" for Sum must be a constant, at my_reduce"); + } + using OpFunc = std::function; + using ValFunc = float (*)(float, float); + struct ReduceTestDescriptor { + string name; + OpFunc get_node; + ValFunc val_func; + float init_val; + }; + std::vector op_test_info{ + {"Sum", CreateReduceOp, [](float x, float y) { return x + y; }, + 0}, + {"Prod", CreateReduceOp, + [](float x, float y) { return x * y; }, 1}, + {"Mean", CreateReduceOp, + [](float x, float y) { return x + y; }, 0}, + {"Min", CreateReduceOp, + [](float x, float y) { return y < x ? y : x; }, 1000}, + {"Max", CreateReduceOp, + [](float x, float y) { return x < y ? 
y : x; }, -1000}}; + + std::vector input_values{1, 2, 3, 4, 5, 6}; + struct TestParams { + std::vector input_dims; + std::vector input_values; + // Helper array contains the same elements as input but permuted in a way + // that the reduction can be calculated over contiguous elements using + // CalcReduce + std::vector helper_array; + std::vector axis; + int stride; // product of input_dims along axis + Status conversion_status; + }; + std::vector params{ + // Out of range tests + TestParams{{2, 3, 1}, input_values, input_values, {3}, 3}, + TestParams{{2, 3, 1}, input_values, input_values, {-4}, 3}, + // Ok tests + TestParams{{2, 3, 1}, input_values, {1, 4, 2, 5, 3, 6}, {0}, 2}, + TestParams{{2, 3, 1}, input_values, input_values, {1}, 3}, + TestParams{{2, 3, 1}, input_values, input_values, {2}, 1}, + TestParams{{2, 3, 1}, input_values, input_values, {0, 1}, 6}, + // Ok tests with negative axis values + TestParams{{2, 3, 1}, input_values, {1, 4, 2, 5, 3, 6}, {-3}, 2}, + TestParams{{2, 3, 1}, input_values, input_values, {-2}, 3}, + TestParams{{2, 3, 1}, input_values, input_values, {-1}, 1}, + TestParams{{2, 3, 1}, input_values, input_values, {-3, 1}, 6}, + }; + + for (bool keep_dims : {false, true}) { + for (auto& op : op_test_info) { + for (auto p : params) { + SCOPED_TRACE(StrCat(op.name, keep_dims ? "keep_dims" : "")); + Reset(); + NodeDef node_def = op.get_node(tf_type_, keep_dims); + + AddTestTensor("input", p.input_dims, p.input_values); + AddTestWeights("axis", {static_cast(p.axis.size())}, + p.axis); + std::vector expected_output_dims(p.input_dims); + + // Set expected output dim and conversion error messages + for (int ax : p.axis) { + int rank = p.input_dims.size(); + if (ax >= rank || ax < -rank) { + p.conversion_status = + errors::InvalidArgument("Axis value of ", ax, + " is out of bounds, must be in " + "range [", + -rank, ", ", rank, "), at my_reduce"); + } else { + int ax_positive = ax >= 0 ? ax : ax + rank; + // Zero marks elements that we will remove later. + expected_output_dims[ax_positive] = keep_dims ? 1 : 0; + if (trt_mode_ == TrtTestMode::kImplicitBatch && + (ax == 0 || ax == -rank)) { + p.conversion_status = errors::Unimplemented( + "TensorRT does not allow manipulation of the batch " + "dimension, at my_reduce"); + } + } + } + expected_output_dims.erase(std::remove(expected_output_dims.begin(), + expected_output_dims.end(), 0), + expected_output_dims.end()); + VLOG(2) << "out dims " + << absl::StrCat("[", absl::StrJoin(expected_output_dims, ","), + "]"); + std::vector expected_values = CalcReduce( + op.name, p.helper_array, p.stride, op.val_func, op.init_val); + TestOpConverter("my_reduce", node_def, expected_output_dims, + p.conversion_status, Status::OK(), + ArrayFloatNear(expected_values)); + } + } + } +} + NodeDef CreateCastOp(DataType tf_type) { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_HALF); @@ -5204,7 +5393,7 @@ TEST_P(OpConverterTest1, ConvertUnary) { { // Input is weights, should fail. 
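// The new ConvertReduce test above computes its expected values on the host:
// the helper array is arranged so every reduced group is contiguous, and each
// group of `m` elements is folded with std::accumulate (then divided by `m`
// for Mean). A standalone sketch of that reference computation, using only
// the standard library:
#include <cassert>
#include <numeric>
#include <string>
#include <vector>

std::vector<float> ReduceContiguous(const std::string& op_name,
                                    const std::vector<float>& input, int m,
                                    float (*op)(float, float), float init) {
  std::vector<float> output(input.size() / m);
  for (size_t i = 0; i < output.size(); ++i) {
    auto begin = input.begin() + i * m;
    output[i] = std::accumulate(begin, begin + m, init, op);
    if (op_name == "Mean") output[i] /= m;
  }
  return output;
}

int main() {
  const std::vector<float> input = {1, 2, 3, 4, 5, 6};  // shape {2, 3, 1}
  // Sum over axis 1 (stride 3): {1+2+3, 4+5+6} == {6, 15}.
  auto sums = ReduceContiguous(
      "Sum", input, 3, [](float x, float y) { return x + y; }, 0.0f);
  assert((sums == std::vector<float>{6.0f, 15.0f}));
  // Mean over axis 1: {2, 5}.
  auto means = ReduceContiguous(
      "Mean", input, 3, [](float x, float y) { return x + y; }, 0.0f);
  assert((means == std::vector<float>{2.0f, 5.0f}));
  return 0;
}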
Reset(); - const NodeDef node_def = CreateUnaryOp(tf_type); + const NodeDef node_def = CreateUnaryOp(tf_type_); AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, @@ -5260,7 +5449,7 @@ TEST_P(OpConverterTest1, ConvertUnary) { if (!op_map.count(op_name)) { FAIL() << "Unary op test map does not contain op " << op_name; } - NodeDef node_def = op_map[op_name].first(tf_type); + NodeDef node_def = op_map[op_name].first(tf_type_); // TODO(bixia): we assume this test is only instantiated for DT_FLOAT for // now. Need to find a better way to express input and output types. @@ -5268,7 +5457,7 @@ TEST_P(OpConverterTest1, ConvertUnary) { // TODO(tfeher): improve tests by defining an expected output data type and // check that. Currently only the shape and values of the output are // checked. - DataType input_tf_type = op_name == "Cast" ? DT_HALF : tf_type; + DataType input_tf_type = op_name == "Cast" ? DT_HALF : tf_type_; std::vector input_values{-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f}; AddTestTensor("input", p.input_dims, input_tf_type, input_values); @@ -5835,7 +6024,7 @@ TEST_P(OpConverterTest2, ConvertPack) { /*axis=*/1, /*expected_output_dims=*/{1, 2, 2, 3}, /*expected_output=*/InitTestVector(12), - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status{error::UNIMPLEMENTED, "The input \"values_1\" for Pack must be a tensor, at " "my_pack"} @@ -5861,7 +6050,7 @@ TEST_P(OpConverterTest2, ConvertPack) { /*axis=*/-4, /*expected_output_dims=*/{2, 1, 2, 3}, /*expected_output=*/InitTestVector(12), - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status{error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the batch " "dimension, at my_pack"} @@ -5921,7 +6110,7 @@ TEST_P(OpConverterTest2, ConvertPack) { }, }; // Inputs have inconsistent shapes, should fail. - if (trt_mode != TrtTestMode::kDynamicShape) { + if (trt_mode_ != TrtTestMode::kDynamicShape) { params.push_back(TestParams{ /*input_shapes=*/{{1, 2, 3}, {1, 3, 2}}, /*partial_input_shapes=*/{{}, {}}, @@ -5941,7 +6130,7 @@ TEST_P(OpConverterTest2, ConvertPack) { // TODO(tfeher) Add dynamic shapes test once TRT handles shape error // decently } - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // Test with mixed dynamic / static shape input tensors params.push_back( TestParams{/*input_shapes=*/{{1, 2, 3}, {1, 2, 3}}, @@ -5957,14 +6146,14 @@ TEST_P(OpConverterTest2, ConvertPack) { const int num_inputs = p.input_shapes.size(); EXPECT_EQ(num_inputs, p.input_values.size()); - NodeDef node_def = GetPackNodeDef(tf_type, num_inputs, p.axis); + NodeDef node_def = GetPackNodeDef(tf_type_, num_inputs, p.axis); // Create inputs. for (int j = 0; j < num_inputs; ++j) { if (j == 1 && p.input_1_is_weight) { AddTestWeights(StrCat("values_", j), p.input_shapes[j], - p.input_values[j], tf_type); + p.input_values[j], tf_type_); } else { - AddTestTensor(StrCat("values_", j), p.input_shapes[j], tf_type, + AddTestTensor(StrCat("values_", j), p.input_shapes[j], tf_type_, p.input_values[j], p.partial_input_shapes[j]); } } @@ -6492,7 +6681,7 @@ TEST_P(OpConverterTest2, ConvertSquaredDifference) { { // Input is a weight, should fail. 
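// The ConvertPack cases above exercise both positive and negative `axis`
// values: Pack inserts a new dimension of size N (the number of inputs) at
// the normalized axis. A standalone sketch of that shape rule (hypothetical
// helper, not converter code):
#include <cassert>
#include <vector>

std::vector<int> PackShape(const std::vector<int>& input_shape, int num_inputs,
                           int axis) {
  const int rank = static_cast<int>(input_shape.size());
  if (axis < 0) axis += rank + 1;  // axis -4 on rank-3 inputs becomes 0.
  std::vector<int> out = input_shape;
  out.insert(out.begin() + axis, num_inputs);
  return out;
}

int main() {
  // Mirrors the test: two {1, 2, 3} inputs packed at axis 1 -> {1, 2, 2, 3}.
  assert((PackShape({1, 2, 3}, 2, 1) == std::vector<int>{1, 2, 2, 3}));
  // axis -4 puts the new dimension in front, {2, 1, 2, 3}, which implicit
  // batch mode rejects because it manipulates the batch dimension.
  assert((PackShape({1, 2, 3}, 2, -4) == std::vector<int>{2, 1, 2, 3}));
  return 0;
}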
Reset(); - NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type); + NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type_); AddTestWeights("x", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); AddTestTensor("y", {1, 1, 2, 3}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, @@ -6519,7 +6708,7 @@ TEST_P(OpConverterTest2, ConvertSquaredDifference) { /*value_y=*/std::vector(7 * 5, 0), /*expected_output_dims=*/{1, 1, 2, 3}, /*expected_output=*/common_input, - trt_mode == TrtTestMode::kDynamicShape + trt_mode_ == TrtTestMode::kDynamicShape ? Status::OK() : errors::InvalidArgument("Infeasible broadcast scheme"), errors::Internal( @@ -6545,7 +6734,7 @@ TEST_P(OpConverterTest2, ConvertSquaredDifference) { for (auto p : params) { Reset(); - NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type); + NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type_); AddTestTensor("x", p.dims_x, p.value_x); AddTestTensor("y", p.dims_y, p.value_y); TestOpConverter("my_squared_diff", node_def, p.expected_output_dims, @@ -6581,7 +6770,7 @@ template void TestConvertResize(OpConverterTest* test) { typedef typename EnumToDataType::Type CType; - std::vector> params{ + std::vector> params { // TODO(b/162442839): Enable the test parameters for TRT 7.1.3.x. #if !IS_TRT_VERSION_GE(7, 1, 3, 0) { diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index a69960005fc..1fc0d13c993 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -241,36 +241,6 @@ int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) { #endif -string GetLinkedTensorRTVersion() { - int major, minor, patch; -#if GOOGLE_CUDA && GOOGLE_TENSORRT - major = NV_TENSORRT_MAJOR; - minor = NV_TENSORRT_MINOR; - patch = NV_TENSORRT_PATCH; -#else - major = 0; - minor = 0; - patch = 0; -#endif - return absl::StrCat(major, ".", minor, ".", patch); -} - -string GetLoadedTensorRTVersion() { - int major, minor, patch; -#if GOOGLE_CUDA && GOOGLE_TENSORRT - int ver = getInferLibVersion(); - major = ver / 1000; - ver = ver - major * 1000; - minor = ver / 100; - patch = ver - minor * 100; -#else - major = 0; - minor = 0; - patch = 0; -#endif - return absl::StrCat(major, ".", minor, ".", patch); -} - absl::string_view GetDeviceName(const Node* node) { if (node->has_assigned_device_name()) { return node->assigned_device_name(); diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index a0505c3f922..7570dff1c9d 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -117,14 +117,6 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims, Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type); Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type); -// Returns a string that includes compile time TensorRT library version -// information {Maj, Min, Patch}. -string GetLinkedTensorRTVersion(); - -// Returns a string that includes runtime time TensorRT library version -// information {Maj, Min, Patch}. -string GetLoadedTensorRTVersion(); - // Returns true if an engine built for cached_shapes can also run actual_shapes. 
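// The two helpers removed from convert/utils.cc above decoded the integer
// returned by TensorRT's getInferLibVersion() into {major, minor, patch};
// equivalent helpers now live in common/utils (see the py_utils_wrapper.cc
// hunk further below). The decoding itself is plain integer arithmetic, shown
// here as a standalone sketch:
#include <cassert>
#include <tuple>

std::tuple<int, int, int> DecodeTrtVersion(int ver) {
  const int major = ver / 1000;
  ver -= major * 1000;
  const int minor = ver / 100;
  const int patch = ver - minor * 100;
  return {major, minor, patch};
}

int main() {
  // getInferLibVersion() reports e.g. 7103 for TensorRT 7.1.3.
  assert(DecodeTrtVersion(7103) == std::make_tuple(7, 1, 3));
  assert(DecodeTrtVersion(6005) == std::make_tuple(6, 0, 5));
  return 0;
}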
bool AreShapesCompatible(const std::vector& actual_shapes, const std::vector& cached_shapes); diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index 58d1c611463..5b2ae822d59 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -800,6 +800,9 @@ StatusOr> TRTEngineOp::GetEngine( TrtUniquePtrType infer(nvinfer1::createInferRuntime(logger)); infer->setGpuAllocator(allocator); + // Need to initialize plugins in order to deserialize engines that contain + // plugins. + MaybeInitializeTrtPlugins(&logger); TrtUniquePtrType static_engine( infer->deserializeCudaEngine(serialized_segment_.c_str(), serialized_segment_.size(), nullptr)); diff --git a/tensorflow/compiler/tf2tensorrt/plugin/plugin_cast.cu.cc b/tensorflow/compiler/tf2tensorrt/plugin/plugin_cast.cu.cc deleted file mode 100644 index 141a7d1f462..00000000000 --- a/tensorflow/compiler/tf2tensorrt/plugin/plugin_cast.cu.cc +++ /dev/null @@ -1,236 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" -#include "tensorflow/core/platform/logging.h" - -#if GOOGLE_CUDA && GOOGLE_TENSORRT -#define EIGEN_USE_GPU // For definition of Eigen::GpuDevice. 
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h" -#include "tensorflow/core/util/gpu_kernel_helper.h" -#include "third_party/tensorrt/NvInfer.h" - -namespace tensorflow { -namespace tensorrt { -using nvinfer1::DataType; -using nvinfer1::Dims; -using nvinfer1::IPluginCreator; -using nvinfer1::IPluginV2; -using nvinfer1::IPluginV2Ext; -using nvinfer1::PluginField; -using nvinfer1::PluginFieldCollection; -using nvinfer1::PluginFieldType; -using nvinfer1::PluginFormat; - -template -__global__ void Cast(const SrcT* input, int num_elements, DstT* output) { - for (int i : CudaGridRangeX(num_elements)) { - output[i] = static_cast(input[i]); - } -} - -template -void RunCast(const SrcT* d_input, int num_elements, DstT* d_output, - cudaStream_t stream) { - const int threads_per_block = 256; - const int blocks_per_grid = - (num_elements + threads_per_block - 1) / threads_per_block; - TF_CHECK_OK(CudaLaunchKernel(Cast, threads_per_block, - blocks_per_grid, 0, stream, d_input, - num_elements, d_output)); -} - -const char* kPluginName = "TfTrtPluginCast"; - -class CastPlugin : public TrtPlugin { - public: - CastPlugin(DataType src_type, DataType dst_type) - : src_type_(src_type), dst_type_(dst_type) {} - - CastPlugin(const void* serialized_data, size_t length) - : TrtPlugin(serialized_data, length) { - const char* buffer = static_cast(serialized_data); - src_type_ = ReadFromBuffer(&buffer); - dst_type_ = ReadFromBuffer(&buffer); - src_dims_ = ReadFromBuffer(&buffer); - } - - CastPlugin(const CastPlugin& rhs) - : TrtPlugin(rhs), - src_type_(rhs.src_type_), - dst_type_(rhs.dst_type_), - src_dims_(rhs.src_dims_) {} - - // Methods from IPluginV2Ext. - - DataType getOutputDataType(int index, const DataType* input_types, - int num_inputs) const override { - DCHECK_EQ(0, index); - DCHECK_EQ(1, num_inputs); - return dst_type_; - } - - bool isOutputBroadcastAcrossBatch(int output_index, - const bool* input_is_broadcasted, - int num_inputs) const override { - return false; - } - - bool canBroadcastInputAcrossBatch(int input_index) const override { - return false; - } - - void configurePlugin(const Dims* input_dims, int num_inputs, - const Dims* output_dims, int num_outputs, - const DataType* input_types, - const DataType* output_types, - const bool* input_is_broadcast, - const bool* output_is_broadcast, - PluginFormat float_format, int max_batch_size) override { - DCHECK_EQ(1, num_inputs); - DCHECK_EQ(1, num_outputs); - DCHECK(src_type_ == input_types[0]); - DCHECK(dst_type_ == output_types[0]); - src_dims_ = input_dims[0]; - } - - IPluginV2Ext* clone() const override { return new CastPlugin(*this); } - - // Methods from IPluginV2. 
- - const char* getPluginType() const override { return kPluginName; }; - - const char* getPluginVersion() const override { return kTfTrtPluginVersion; }; - - int getNbOutputs() const override { return 1; } - - Dims getOutputDimensions(int index, const Dims* inputs, - int num_input_dims) override { - DCHECK_EQ(0, index); - DCHECK_EQ(1, num_input_dims); - return inputs[0]; - } - - bool supportsFormat(DataType type, PluginFormat format) const override { - return type == DataType::kFLOAT || type == DataType::kINT32; - } - - size_t getWorkspaceSize(int max_batch_size) const override { return 0; } - - int enqueue(int batch_size, const void* const* inputs, void** outputs, void*, - cudaStream_t stream) override { - int num_elements = batch_size; - for (int i = 0; i < src_dims_.nbDims; i++) { - num_elements *= src_dims_.d[i]; - } - const void* input = inputs[0]; - void* output = outputs[0]; - DCHECK_NE(static_cast(src_type_), static_cast(dst_type_)); - - switch (src_type_) { - case DataType::kFLOAT: - RunCast(reinterpret_cast(input), num_elements, - reinterpret_cast(output), stream); - break; - case DataType::kINT32: - RunCast(reinterpret_cast(input), num_elements, - reinterpret_cast(output), stream); - break; - default: - return 1; // Indicates a failure. - } - return 0; - } - - size_t getSerializationSize() const override { - return 2 * sizeof(DataType) + sizeof(Dims); - } - - void serialize(void* serialized_data) const override { - char* buffer = static_cast(serialized_data); - WriteToBuffer(src_type_, &buffer); - WriteToBuffer(dst_type_, &buffer); - WriteToBuffer(src_dims_, &buffer); - } - - private: - DataType src_type_; - DataType dst_type_; - Dims src_dims_; -}; - -class CastPluginCreator : public IPluginCreator { - public: - CastPluginCreator() { - setPluginNamespace(kTfTrtPluginNamespace); - plugin_fields_.emplace_back( - PluginField("SrcT", nullptr, PluginFieldType::kINT32, 1)); - plugin_fields_.emplace_back( - PluginField("DstT", nullptr, PluginFieldType::kINT32, 1)); - - field_collection_.nbFields = plugin_fields_.size(); - field_collection_.fields = plugin_fields_.data(); - } - - const char* getPluginName() const override { return kPluginName; } - - const char* getPluginVersion() const override { return kTfTrtPluginVersion; } - - const PluginFieldCollection* getFieldNames() override { - return &field_collection_; - } - - IPluginV2* createPlugin( - const char* name, - const PluginFieldCollection* field_collection) override { - const PluginField* fields = field_collection->fields; - DataType src_type, dst_type; - for (int i = 0; i < field_collection->nbFields; ++i) { - const char* attr_name = fields[i].name; - if (!strcmp(attr_name, "SrcT")) { - src_type = *static_cast(fields[i].data); - } else if (!strcmp(attr_name, "DstT")) { - dst_type = *static_cast(fields[i].data); - } else { - return nullptr; - } - } - return new CastPlugin(src_type, dst_type); - } - - IPluginV2* deserializePlugin(const char* name, const void* serial_data, - size_t serial_len) override { - return new CastPlugin(serial_data, serial_len); - } - - void setPluginNamespace(const char* plugin_namespace) override { - namespace_ = plugin_namespace; - } - - const char* getPluginNamespace() const override { return namespace_.c_str(); } - - private: - PluginFieldCollection field_collection_; - std::vector plugin_fields_; - std::string namespace_; -}; - -REGISTER_TFTRT_PLUGIN(CastPluginCreator); - -} // namespace tensorrt -} // namespace tensorflow - -#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git 
a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index 1337a733f91..021e28ec6f0 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -676,6 +676,21 @@ void AddSegmentForNode(const grappler::GraphProperties* graph_properties, device_name); } +bool OpBatchSizeExceedMaximumBatchSize( + const grappler::GraphProperties* graph_properties, const Node* node, + bool use_implicit_batch, absl::optional maximum_batch_size) { + ClusterBatchSize cluster_batch_size = + GetClusterBatchSizeForNode(graph_properties, node, use_implicit_batch); + if (cluster_batch_size.HasStaticBatchValue() && + maximum_batch_size.has_value() && + cluster_batch_size.GetStaticBatchValue() > maximum_batch_size.value()) { + VLOG(2) << "OP batch size " << cluster_batch_size.GetStaticBatchValue() + << " max_batch_size " << maximum_batch_size.value(); + return true; + } + return false; +} + } // namespace Status SegmentGraph(const Graph* tf_graph, @@ -690,6 +705,10 @@ Status SegmentGraph(const Graph* tf_graph, "Explicit batch mode should allow dynamic non-batch dimensions"); } + if (options.use_implicit_batch && !options.maximum_batch_size.has_value()) { + return errors::Internal("Implicit batch mode requires maximum_batch_size"); + } + if (!options.allow_dynamic_non_batch_dim && !graph_properties) { return errors::Internal( "Need graph propertities to disallow dynamic non-batch dimensions"); @@ -768,6 +787,14 @@ Status SegmentGraph(const Graph* tf_graph, << "(Op type: " << node->tf_node()->type_string() << "), " << "(Op name: " << node->name() << ")"; exclude_node("Denylisted with the env var TF_TRT_OP_DENYLIST"); + } else if (OpBatchSizeExceedMaximumBatchSize( + graph_properties, node->tf_node(), + options.use_implicit_batch, options.maximum_batch_size)) { + LOG_WARNING_WITH_PREFIX + << "Implicit batch mode requires OP batch size not larger than " + << "the converter maximum batch size: " + << "(Op name: " << node->name() << ")"; + exclude_node("OP batch size too large"); } else { VLOG(2) << "Accepted as a TF-TRT candidate, " << "(Op type: " << node->tf_node()->type_string() << "), " diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.h b/tensorflow/compiler/tf2tensorrt/segment/segment.h index 3f79983cfd2..bab6e089fa4 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.h +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/costs/graph_properties.h" @@ -38,6 +39,9 @@ struct SegmentOptions { // Segment must contain at least this many nodes. int minimum_segment_size = 2; bool use_implicit_batch = true; + // The maximum batch size used to build the engines in the graph, when + // use_implicit_batch is true. + absl::optional maximum_batch_size = absl::nullopt; // When use_implicit_batch is false or when we are building dynamic engines, // we allow dynamic non-batch dimensions. 
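// The segmenter change above rejects candidate ops whose static batch size
// exceeds the engine's maximum batch size; the option is optional-valued
// because explicit-batch mode has no such limit. A standalone sketch of the
// same check using std::optional (the TensorFlow code uses absl::optional):
#include <cassert>
#include <optional>

bool ExceedsMaximumBatchSize(std::optional<int> static_batch_size,
                             std::optional<int> maximum_batch_size) {
  // Only reject when both values are known and the op's batch is too large.
  return static_batch_size.has_value() && maximum_batch_size.has_value() &&
         *static_batch_size > *maximum_batch_size;
}

int main() {
  assert(ExceedsMaximumBatchSize(5, 1));              // 5 > 1: excluded
  assert(!ExceedsMaximumBatchSize(5, 1000));          // within the limit
  assert(!ExceedsMaximumBatchSize(std::nullopt, 1));  // dynamic batch: keep
  assert(!ExceedsMaximumBatchSize(5, std::nullopt));  // no limit configured
  return 0;
}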
bool allow_dynamic_non_batch_dim = false; diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc index bf277328fe7..ee406c9743f 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc @@ -108,8 +108,9 @@ class SegmentTest : public ::testing::Test { segment_options_.allow_dynamic_non_batch_dim = true; } - void EnableImplicitBatchModeForStaticEngine() { + void EnableImplicitBatchModeForStaticEngine(int maximum_batch_size = 1000) { segment_options_.use_implicit_batch = true; + segment_options_.maximum_batch_size = maximum_batch_size; segment_options_.allow_dynamic_non_batch_dim = false; } @@ -487,7 +488,11 @@ TEST_F(SegmentTest, TwoChainsDiffBatchSizes) { const std::set all_nodes = {"const-scalar", "output-0", "output-1"}; EnableImplicitBatchModeForStaticEngine(); RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, - {{"output-0", "const-scalar"}}); + /*expected_segments=*/{{"output-0", "const-scalar"}}); + + EnableImplicitBatchModeForStaticEngine(1); + RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, + /*expected_segments=*/{}); } TEST_F(SegmentTest, SameRankImplicitBroadcastingStaticBatchSize) { diff --git a/tensorflow/compiler/tf2tensorrt/segment/union_find.h b/tensorflow/compiler/tf2tensorrt/segment/union_find.h index b91f5771ce5..54bbc251e4f 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/union_find.h +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.h @@ -109,8 +109,9 @@ class ClusterBatchSize { return s; } - private: bool HasStaticBatchValue() const { return static_batch_value_.has_value(); } + + private: bool HasDynamicBatchValue() const { return has_dynamic_batch_value_; } private: diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc index a8e24aa8983..3f8a11f7410 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc @@ -41,31 +41,5 @@ bool IsGoogleTensorRTEnabled() { #endif } -void GetLinkedTensorRTVersion(int* major, int* minor, int* patch) { -#if GOOGLE_CUDA && GOOGLE_TENSORRT - *major = NV_TENSORRT_MAJOR; - *minor = NV_TENSORRT_MINOR; - *patch = NV_TENSORRT_PATCH; -#else - *major = 0; - *minor = 0; - *patch = 0; -#endif -} - -void GetLoadedTensorRTVersion(int* major, int* minor, int* patch) { -#if GOOGLE_CUDA && GOOGLE_TENSORRT - int ver = getInferLibVersion(); - *major = ver / 1000; - ver = ver - *major * 1000; - *minor = ver / 100; - *patch = ver - *minor * 100; -#else - *major = 0; - *minor = 0; - *patch = 0; -#endif -} - } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils.h b/tensorflow/compiler/tf2tensorrt/utils/py_utils.h index f52bb6f1bad..9b24eb36cf9 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/py_utils.h +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.h @@ -21,12 +21,6 @@ namespace tensorrt { bool IsGoogleTensorRTEnabled(); -// Return compile time TensorRT library version information {Maj, Min, Patch}. -void GetLinkedTensorRTVersion(int* major, int* minor, int* patch); - -// Return runtime time TensorRT library version information {Maj, Min, Patch}. 
-void GetLoadedTensorRTVersion(int* major, int* minor, int* patch); - } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc index 03f77c6bd5f..52252f125ac 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc @@ -16,18 +16,15 @@ limitations under the License. #include #include "pybind11/pybind11.h" +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/py_utils.h" std::tuple get_linked_tensorrt_version() { - int major, minor, patch; - tensorflow::tensorrt::GetLinkedTensorRTVersion(&major, &minor, &patch); - return std::tuple{major, minor, patch}; + return tensorflow::tensorrt::GetLinkedTensorRTVersion(); } std::tuple get_loaded_tensorrt_version() { - int major, minor, patch; - tensorflow::tensorrt::GetLoadedTensorRTVersion(&major, &minor, &patch); - return std::tuple{major, minor, patch}; + return tensorflow::tensorrt::GetLoadedTensorRTVersion(); } PYBIND11_MODULE(_pywrap_py_utils, m) { diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc index d4f3a524577..a73877bc3cc 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc @@ -74,7 +74,7 @@ void* TRTDeviceAllocator::allocate(uint64_t size, uint64_t alignment, // algorithm uses too much memory. If we don't fail immediately building the // engine can be *very* slow with TensorRT7 when GPU memory is limited. AllocationAttributes attributes; - attributes.no_retry_on_failure = true; + attributes.retry_on_failure = false; void* mem = allocator_->AllocateRaw(alignment, total_size, attributes); if (!mem) return nullptr; diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc index 70a0a9a7b65..2f31865751f 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #if GOOGLE_CUDA && GOOGLE_TENSORRT @@ -35,14 +36,16 @@ void TrtShapeOptimizationProfile::InitProfiles() { << "for each input (min=opt=max)."; } for (auto& shape_vec : input_shapes_) { - std::vector dimvec; - for (auto& shape : shape_vec) { - dimvec.push_back(TensorShapeToTrtDims(shape, false)); + if (!shape_vec.empty()) { + std::vector dimvec(shape_vec.size()); + absl::c_transform(shape_vec, dimvec.begin(), [](TensorShape shape) { + return TensorShapeToTrtDims(shape, false); + }); + // Set min=opt=max. + OptimizationProfileConfig profConfig{dimvec, dimvec, dimvec}; + profiles_.push_back(std::move(profConfig)); + VLOG(1) << "Created profile " << profiles_.back().DebugString(); } - // We set min=opt=max. 
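// The InitProfiles() rewrite above maps every collected input-shape vector to
// TensorRT dims with a container transform, skips empty shape vectors, and
// sets min = opt = max so each profile matches exactly the shapes that were
// seen. A standalone sketch of that flow with simplified stand-in types:
#include <algorithm>
#include <cassert>
#include <vector>

using Shape = std::vector<int>;  // stands in for TensorShape
using Dims = std::vector<int>;   // stands in for nvinfer1::Dims

struct ProfileConfig {
  std::vector<Dims> min, opt, max;
};

std::vector<ProfileConfig> InitProfiles(
    const std::vector<std::vector<Shape>>& collected_input_shapes) {
  std::vector<ProfileConfig> profiles;
  for (const auto& shape_vec : collected_input_shapes) {
    if (shape_vec.empty()) continue;  // nothing was recorded for this run
    std::vector<Dims> dimvec(shape_vec.size());
    std::transform(shape_vec.begin(), shape_vec.end(), dimvec.begin(),
                   [](const Shape& shape) { return Dims(shape); });
    profiles.push_back(ProfileConfig{dimvec, dimvec, dimvec});  // min=opt=max
  }
  return profiles;
}

int main() {
  auto profiles = InitProfiles({{{1, 28, 28}}, {}, {{4, 28, 28}}});
  assert(profiles.size() == 2);  // the empty entry is skipped
  assert(profiles[0].min == profiles[0].max);
  return 0;
}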
- OptimizationProfileConfig profConfig{dimvec, dimvec, dimvec}; - profiles_.push_back(std::move(profConfig)); - VLOG(1) << "Created profile " << profiles_.back().DebugString(); } } diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 1e57c11b2cf..1a91f54afc9 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -337,7 +337,6 @@ cc_library( visibility = [":friends"], deps = [ ":common", - ":frontend_attributes_util", ":host_compute_metadata_proto_cc", ":rearrange_function_argument", ":sharding_util", @@ -353,23 +352,17 @@ cc_library( "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", - "//tensorflow/compiler/jit:xla_cluster_util", + "//tensorflow/compiler/mlir:array_container_utils", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", - "//tensorflow/compiler/tf2xla/lib:util", - "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -378,11 +371,8 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", ], @@ -787,6 +777,7 @@ tf_cc_test( "//tensorflow/cc:function_ops", "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:xla_cluster_util", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu_internal", @@ -828,9 +819,9 @@ cc_library( ":frontend_attributes_util", ":functionalize_control_flow_util", ":tf2xla_util", - "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:union_find", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -856,9 +847,9 @@ cc_library( ":functionalize_control_flow_util", ":functionalize_while", ":tf2xla_util", - "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:union_find", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -944,9 +935,9 @@ cc_library( ":functionalize_cond", ":functionalize_control_flow_util", ":tf2xla_util", - "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:union_find", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", 
"//tensorflow/core:framework", @@ -1087,6 +1078,7 @@ tf_cuda_cc_test( "//tensorflow/cc:ops", "//tensorflow/cc:scope", "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:flags", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc index 936b74f7b33..c7c8702b49b 100644 --- a/tensorflow/compiler/tf2xla/const_analysis_test.cc +++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/graph/algorithm.h" @@ -217,5 +218,10 @@ TEST(ConstAnalysisTest, RespectExplicitAttr_1) { EXPECT_EQ(const_args, std::vector({true})); } +static bool Initialized = [] { + tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true; + return true; +}(); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 54abccb4cfc..452b102fade 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -25,9 +25,10 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/graph_to_functiondef.h" diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 10b26f9801c..2a3e35e0ffd 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -23,12 +23,12 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/functionalize_while.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" @@ -46,12 +46,254 @@ limitations under the License. namespace tensorflow { +// Helper functions for functionalizing control flow in functions. + +// Maps function name to +// - new function name, if the function body was functionalized +// - absl::nullopt, if not +using FuncMap = std::map>; +using FuncMapIter = std::map>::const_iterator; + +// Returns whether function has been processed before. 
+bool FunctionHasBeenProcessed(FuncMapIter func_iter, const FuncMap* func_map) { + return func_iter != func_map->end(); +} + +// Returns whether function has been modified (i.e., functionalized) before. +bool FunctionHasBeenModified(FuncMapIter func_iter) { + return func_iter->second.has_value(); +} + +// Returns a name for the new functionalized version of a function. +string GetNewFunctionName( + const string& func_name, Node* n, + AssociatedFunctionInfo::AssociatedFunctionType func_type, + FunctionLibraryDefinition* fld) { + // For SymbolicGradient, `func_name` is always "SymbolicGradient" which + // is not very informative. Use node name instead. + return ( + func_type == + AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient + ? fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_")) + : fld->UniqueFunctionName(absl::StrCat(func_name, "_f15n_"))); +} + +// Returns name to which a modified function has been mapped. +const string& GetMappedFunctionName(FuncMapIter func_iter) { + DCHECK(func_iter->second.has_value()); + return func_iter->second.value(); +} + +// Updates `func_map` with function given by `canonicalized_name`. +void UpdateFunctionMap(FuncMap* func_map, const string& canonicalized_name, + const string& new_func_name, bool function_modified) { + // If function was modified store its new name, otherwise add empty entry to + // record that function has been processed and does not need to be rewritten. + (*func_map)[canonicalized_name] = + function_modified ? absl::make_optional(new_func_name) : absl::nullopt; +} + +// Adds new function def to graph's function library if necessary. +Status AddFunctionDefToGraphLibrary( + const string& func_name, const AssociatedFunctionInfo& associated_function, + Graph* graph, FunctionLibraryDefinition* fld) { + const OpRegistrationData* op_reg_data; + // We have to be careful with adding the function def since there are three + // different `OpRegistryInterface`s involved here: + // `fld`, `graph->flib_def()` and `graph->flib_def().default_registry()`. + // We have already added the function def to `fld` before calling this + // function but for the subsequent `RewriteAssociatedFunction` call we need + // the function def to be in one of the other two registries, otherwise + // `RewriteAssociatedFunction` will fail for the `kFunctionCallNode` case + // because it cannot find the associated function def. + // On the other hand, we should not add the function def if it is already + // contained in one of the last two registries, this would lead to errors when + // the function def is already in one registry and we try to add it to the + // other one (if we try to add it to the same it's fine). This can happen in + // cases where one of the last two registries is identical to `fld` (which we + // already updated). + // Therefore, before adding the function def we have to check if it's already + // contained in either `graph->flib_def()` or + // `graph->flib_def().default_registry()` which is done in the following line + // (we have to use `LookUp` instead of `Contains` or `Find` because the latter + // both don't check the default registry). + if (graph->flib_def().LookUp(func_name, &op_reg_data).ok()) + return Status::OK(); + + const FunctionDef* new_fdef = fld->Find(func_name); + DCHECK(new_fdef != nullptr); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = *new_fdef; + return graph->AddFunctionLibrary(fdef_lib); +} + +// Functionalizes function given by `func_name`. Update `func_map` accordingly. 
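// The refactor above threads a FuncMap (canonicalized function name ->
// optional new name) through the pass: an entry holding a value means the
// function was functionalized and rewritten under the new name, while an
// entry holding nullopt means it was already visited and needs no rewrite.
// A standalone sketch of that memoization pattern with std::optional
// (absl::optional in the TensorFlow code) and hypothetical names:
#include <cassert>
#include <map>
#include <optional>
#include <string>

using FuncMap = std::map<std::string, std::optional<std::string>>;

// Records the outcome of processing `name`: store the rewritten name if the
// body changed, otherwise store nullopt so later callers can skip it.
void RecordProcessed(FuncMap* func_map, const std::string& name,
                     const std::string& new_name, bool modified) {
  (*func_map)[name] = modified ? std::make_optional(new_name) : std::nullopt;
}

int main() {
  FuncMap func_map;
  RecordProcessed(&func_map, "cond_true", "cond_true_f15n_0",
                  /*modified=*/true);
  RecordProcessed(&func_map, "identity_fn", "", /*modified=*/false);

  auto it = func_map.find("cond_true");
  assert(it != func_map.end() && it->second.has_value());   // rewritten
  assert(it->second.value() == "cond_true_f15n_0");

  it = func_map.find("identity_fn");
  assert(it != func_map.end() && !it->second.has_value());  // visited, unchanged
  assert(func_map.find("unseen_fn") == func_map.end());     // not processed yet
  return 0;
}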
+Status FunctionalizeControlFlowForFunction( + const string& func_name, const string& new_func_name, + const protobuf::Map& attrs, + FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, + FuncMap* func_map, bool* function_modified, + const NodeFilter& node_filter = {}); + +// Functionalizes all functions that are (directly or indirectly) associated to +// any node in `graph`. Adds processed functions to `func_map`. +Status FunctionalizeControlFlowForNodeAssociatedFunctions( + FuncMap* func_map, Graph* graph, FunctionLibraryDefinition* fld, + FunctionLibraryRuntime* flr, bool* any_function_modified, + const NodeFilter& node_filter) { + std::vector>> + nodes_to_associated_functions; + for (auto* n : graph->nodes()) { + auto associated_functions = GetAssociatedFunctions(*n, fld); + if (!associated_functions.empty()) { + nodes_to_associated_functions.push_back({n, associated_functions}); + } + } + for (const auto& pair : nodes_to_associated_functions) { + Node* n = pair.first; + auto associated_functions = pair.second; + for (auto& associated_function : associated_functions) { + // Note that if `n` is a function call node, then potential calls of + // `RewriteAssociatedFunction` below might delete `n` and create a new + // node instead, making `n` an invalid pointer. That's fine because in + // that case `n` only has one associated function, so this loop has only + // one iteration and we don't use `n` again after the rewrite. + // The invariant is guaranteed by `GetAssociatedFunctions` and confirmed + // below. + DCHECK(associated_function.type() != + AssociatedFunctionInfo::kFunctionCallNode || + associated_functions.size() == 1); + + // Process one node-function-pair. + string func_name = associated_function.func_name(); + string canonicalized_name = + Canonicalize(func_name, AttrSlice(&associated_function.attrs())); + auto func_iter = func_map->find(canonicalized_name); + string new_func_name; + if (FunctionHasBeenProcessed(func_iter, func_map)) { + if (FunctionHasBeenModified(func_iter)) { + *any_function_modified = true; + new_func_name = GetMappedFunctionName(func_iter); + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + graph, n, fld, associated_function, new_func_name)); + } + continue; + } + // Function is processed for the first time. + bool function_modified = false; + new_func_name = + GetNewFunctionName(func_name, n, associated_function.type(), fld); + // Perform functionalization for current function. + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + func_name, new_func_name, associated_function.attrs(), fld, flr, + func_map, &function_modified, node_filter)); + UpdateFunctionMap(func_map, canonicalized_name, new_func_name, + function_modified); + if (function_modified) { + *any_function_modified = true; + TF_RETURN_IF_ERROR(AddFunctionDefToGraphLibrary( + new_func_name, associated_function, graph, fld)); + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + graph, n, fld, associated_function, new_func_name)); + } + } + } + return Status::OK(); +} + +Status FunctionalizeControlFlowForFunction( + const string& func_name, const string& new_func_name, + const protobuf::Map& attrs, + FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, + FuncMap* func_map, bool* function_modified, const NodeFilter& node_filter) { + *function_modified = false; + + // Convert the function to a graph. 
+ FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); + Status ret_status = Status::OK(); + auto cleanup_handle = gtl::MakeCleanup([&]() { + auto s = flr->ReleaseHandle(handle); + if (!s.ok()) { + ret_status.Update(s); + } + }); + const FunctionBody* body = flr->GetFunctionBody(handle); + Graph* g = body->graph; + + // Check if the graph has Switch or Merge node. + bool has_switch_or_merge = false; + for (Node* n : body->graph->nodes()) { + // Skip nodes that are filtered out. + if (node_filter && !node_filter(n)) continue; + if (n->type_string() == "Switch" || n->type_string() == "Merge") { + has_switch_or_merge = true; + break; + } + } + // Before functionalizing control flow in `g` we functionalize control flow + // in functions (directly or indirectly) associated with nodes in `g`. + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions( + func_map, g, fld, flr, function_modified, node_filter)); + + if (has_switch_or_merge) { + *function_modified = true; + + // Functionalize the function body. + if (VLOG_IS_ON(4)) { + DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *g, fld); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld, node_filter)); + if (VLOG_IS_ON(4)) { + DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, + fld); + } + } + if (*function_modified) { + // Add rewritten FunctionDef into library. + FunctionDef functionalized_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*g, new_func_name, &functionalized_fdef)); + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + } + } + + return ret_status; +} + Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library, - const NodeFilter& node_filter) { + const NodeFilter& node_filter, + bool include_functions) { VLOG(2) << "FunctionalizeControlFlow (initial): " << DumpGraphToFile("functionalize_initial", *graph, library); + if (include_functions) { + // Functionalize control flow in functions that are (directly or indirectly) + // associated with a node in `graph`. + auto pflr = absl::make_unique( + /*device_mgr=*/nullptr, tensorflow::Env::Default(), + /*config=*/nullptr, TF_GRAPH_DEF_VERSION, library, + tensorflow::OptimizerOptions()); + // `pflr` has only one `FunctionLibraryRuntime`, for `kDefaultFLRDevice` + // (because we constructed it with `device_mgr = nullptr`). + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + FuncMap func_map; + bool modified = false; + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions( + &func_map, graph, library, flr, &modified, node_filter)); + } // Functionalize and remove while loops from graph. 
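// FunctionalizeControlFlowForFunction above instantiates a runtime handle and
// registers a cleanup object so the handle is released on every return path
// (the TensorFlow code additionally folds the release status into the status
// it returns). A standalone sketch of that scope-guard idiom with a
// simplified Status type and a hypothetical resource:
#include <cassert>
#include <functional>
#include <string>
#include <utility>

struct Status {
  std::string error;  // empty means OK
  bool ok() const { return error.empty(); }
};

class Cleanup {  // minimal stand-in for gtl::MakeCleanup
 public:
  explicit Cleanup(std::function<void()> f) : f_(std::move(f)) {}
  ~Cleanup() { if (f_) f_(); }
  Cleanup(const Cleanup&) = delete;
  Cleanup& operator=(const Cleanup&) = delete;

 private:
  std::function<void()> f_;
};

int live_handles = 0;  // tracks the hypothetical resource

Status ProcessFunction(bool fail_early) {
  ++live_handles;                           // "Instantiate" the handle.
  Cleanup release([] { --live_handles; });  // "ReleaseHandle" on every path.
  if (fail_early) return Status{"functionalization failed"};
  return Status{};
}

int main() {
  assert(ProcessFunction(/*fail_early=*/false).ok());
  assert(live_handles == 0);  // released on the normal path
  assert(!ProcessFunction(/*fail_early=*/true).ok());
  assert(live_handles == 0);  // released on the early return too
  return 0;
}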
TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library, node_filter)); @@ -68,153 +310,19 @@ Status FunctionalizeControlFlow(Graph* graph, Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, FunctionLibraryDefinition* library, - const NodeFilter& node_filter) { + const NodeFilter& node_filter, + bool include_functions) { FunctionDefLibrary function_lib = graph_def->library(); Graph graph(OpRegistry::Global()); TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph)); - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library, node_filter)); + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library, node_filter, + include_functions)); graph.ToGraphDef(graph_def); std::swap(*graph_def->mutable_library(), function_lib); return Status::OK(); } -Status FunctionalizeControlFlowForFunction( - const string& func_name, const string& new_func_name, - const protobuf::Map& attrs, - FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, - std::map>* canonicalized_name_to_new_name, - bool* modified) { - *modified = false; - - // Convert the function to Graph. - FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); - Status ret_status = Status::OK(); - auto cleanup_handle = gtl::MakeCleanup([&]() { - auto s = flr->ReleaseHandle(handle); - if (!s.ok()) { - ret_status.Update(s); - } - }); - const FunctionBody* body = flr->GetFunctionBody(handle); - Graph* g = body->graph; - - // Check if the graph has Switch or Merge node. - bool has_switch_or_merge = false; - for (Node* n : body->graph->nodes()) { - if (n->type_string() == "Switch" || n->type_string() == "Merge") { - has_switch_or_merge = true; - break; - } - } - // We cannot return here directly if the graph has no Switch/Merge. - // It might contain function call nodes, or If/While nodes with Switch/Merge - // in function body. We still need to rewrite those functions and modify - // corresponding nodes. - - // If any node has associated functions, functionalize them first. - // Gather nodes with associated functions first, because rewriting those nodes - // might involve node deletion/addition. Avoid modifying nodes while iterating - // it. - std::vector>> - nodes_to_associated_functions; - for (auto* n : g->nodes()) { - auto associated_functions = GetAssociatedFunctions(*n, fld); - if (!associated_functions.empty()) { - nodes_to_associated_functions.push_back({n, associated_functions}); - } - } - for (const auto& iter : nodes_to_associated_functions) { - Node* n = iter.first; - auto associated_functions = iter.second; - for (auto& associated_function : associated_functions) { - string name = associated_function.func_name(); - string canonicalized_name = - Canonicalize(name, AttrSlice(&associated_function.attrs())); - auto iter = canonicalized_name_to_new_name->find(canonicalized_name); - string new_name; - bool function_modified; - if (iter != canonicalized_name_to_new_name->end()) { - // If we already processed this function, check if it was rewritten. If - // the function was rewritten, the entry will be non-empty. Otherwise - // the entry will be empty. - function_modified = iter->second.has_value(); - if (function_modified) { - new_name = iter->second.value(); - } - } else { - if (associated_function.type() == - AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) { - // For SymbolicGradient, `name` is always "SymbolicGradient", - // which is not very informative. Use node name instead. 
- new_name = fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_")); - } else { - new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); - } - TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( - name, new_name, associated_function.attrs(), fld, flr, - canonicalized_name_to_new_name, &function_modified)); - if (function_modified) { - // If the function was rewritten, add an non-empty entry. So later we - // know we have processed this function, and it was rewritten into - // another function. - (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; - } else { - // If the function was not rewritten, add an empty entry. So later - // we know we have processed this function, and it does not need to be - // rewritten. - (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt; - } - } - if (function_modified) { - *modified = true; - - // Notice that if "n" is a function call, RewriteAssociatedFunction() - // will delete it and create a new node instead, making "n" an invalid - // pointer. That's fine because in that case, associated_functions will - // only have one member and the loop will only run once. - TF_RETURN_IF_ERROR(RewriteAssociatedFunction( - g, n, fld, associated_function, new_name)); - } - } - } - - if (has_switch_or_merge) { - *modified = true; - - // Functionalize the function body. - if (VLOG_IS_ON(4)) { - DumpGraphToFile( - absl::StrCat("functionalize_control_flow_before_fdef_", func_name), - *g, fld); - } - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld)); - if (VLOG_IS_ON(4)) { - DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, - fld); - } - } - - if (*modified) { - // Add rewritten FunctionDef into library. - FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*g, new_func_name, &functionalized_fdef)); - if (func_name == new_func_name) { - VLOG(2) << "Replacing function " << func_name; - TF_RETURN_IF_ERROR( - fld->ReplaceFunction(new_func_name, functionalized_fdef)); - } else { - VLOG(2) << "Adding function " << new_func_name; - TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); - } - } - - return ret_status; -} - Status FunctionalizeControlFlowForXlaPass::Run( const GraphOptimizationPassOptions& options) { Graph* graph = options.graph->get(); @@ -241,7 +349,7 @@ Status FunctionalizeControlFlowForXlaPass::Run( // XlaLaunch ops are generated by EncapsulateXlaComputationsPass. {"XlaLaunch", "function"}, }; - std::map> canonicalized_name_to_new_name; + FuncMap func_map; bool fld_modified = false; for (Node* n : graph->nodes()) { auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); @@ -258,7 +366,7 @@ Status FunctionalizeControlFlowForXlaPass::Run( bool modified; TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name, &modified)); + &func_map, &modified)); if (modified) { n->ClearAttr(func_attr); func.set_name(new_func_name); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index f9e751e2d67..46abae27878 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -30,6 +30,13 @@ namespace tensorflow { // // If `node_filter` is defined, then only loops and conditions for whose // nodes `node_filter` returns true are functionalized. 
+ +// If `include_functions` is true, then loops and conditions inside of functions +// that are associated with nodes in `graph` (e.g., a function called from a +// node in `graph`) are also functionalized, otherwise they are not. +// This also handles transitive cases, e.g., a function body will be +// functionalized when it is called in another function that is called by some +// node in `graph` (and so on). The node filter also applies here. // // Precondition: // For any node in a loop or condition for which `node_filter` returns true, @@ -43,11 +50,13 @@ namespace tensorflow { // satisfies the above conditions. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library, - const NodeFilter& node_filter = {}); + const NodeFilter& node_filter = {}, + bool include_functions = false); Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, FunctionLibraryDefinition* library, - const NodeFilter& node_filter = {}); + const NodeFilter& node_filter = {}, + bool include_functions = false); // This pass looks at the graph, and turns V1 control flow structure // (Switch/Merge/etc.) into V2 control flow structure (If/While). diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 79a042ad680..951ebdd7ec1 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -27,12 +27,15 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -63,18 +66,41 @@ Status FindIfThenAndElse(const GraphDef& graph, string* op_name, // math_ops.less(y, x), lambda: math_ops.multiply(y, 17), // lambda: math_ops.add(x, 23)) // -// Tests different node filters. -class ConditionalTestFixture : public ::testing::TestWithParam { +// Tests different node filters and functionalization inside of a function. 
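A minimal sketch of how a caller might opt in to the new flag, using only the declarations above; the wrapper name `FunctionalizeEverything` and its placement are illustrative, not part of the patch:

#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"

namespace tensorflow {

// Functionalize V1 control flow in the graph and, via `include_functions`,
// in every function reachable from the graph's nodes. The empty node filter
// means no node is excluded.
Status FunctionalizeEverything(GraphDef* graph_def,
                               FunctionLibraryDefinition* flib) {
  return FunctionalizeControlFlowForGraphDef(graph_def, flib,
                                             /*node_filter=*/{},
                                             /*include_functions=*/true);
}

}  // namespace tensorflow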
+class ConditionalTestFixture + : public ::testing::TestWithParam> { protected: - void SetUp() override { restrict_to_tpu_nodes_ = GetParam(); } + void SetUp() override { + restrict_to_tpu_nodes_ = std::get<0>(GetParam()); + wrap_condition_in_function_ = std::get<1>(GetParam()); + } void RunTest(); private: + void BuildCondGraph(Graph* cond_graph); + void CheckGraphDef(const GraphDef& graph_def, + const FunctionLibraryDefinition& library); + bool restrict_to_tpu_nodes_ = false; + bool wrap_condition_in_function_ = false; }; -void ConditionalTestFixture::RunTest() { - Graph graph(OpRegistry::Global()); +TEST_P(ConditionalTestFixture, ConditionalTests) { RunTest(); } + +INSTANTIATE_TEST_SUITE_P( + FunctionalizeControlFlow, ConditionalTestFixture, + ::testing::Combine(::testing::Bool(), ::testing::Bool()), + [](const ::testing::TestParamInfo& + info) { + bool restrict_to_tpu_nodes = std::get<0>(info.param); + bool wrap_cond_in_function = std::get<1>(info.param); + string name = + absl::StrCat(restrict_to_tpu_nodes ? "with_filter" : "without_filter", + wrap_cond_in_function ? "_in_function" : "_in_graph"); + return name; + }); + +void ConditionalTestFixture::BuildCondGraph(Graph* cond_graph) { { Scope scope = Scope::NewRootScope().ExitOnError(); @@ -102,13 +128,117 @@ void ConditionalTestFixture::RunTest() { auto merge = ops::Merge(scope.WithOpName("cond/Merge"), std::initializer_list{add, mul}); - TF_EXPECT_OK(scope.ToGraph(&graph)); + TF_EXPECT_OK(scope.ToGraph(cond_graph)); // Set `_tpu_replicate` attribute for all nodes. - for (Node* n : graph.nodes()) { + for (Node* n : cond_graph->nodes()) { n->AddAttr("_tpu_replicate", "cluster"); } } +} + +void ConditionalTestFixture::CheckGraphDef( + const GraphDef& graph_def, const FunctionLibraryDefinition& library) { + string op_name; + NameAttrList then_fn; + NameAttrList else_fn; + TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); + InstantiationResultForTest else_result; + TF_EXPECT_OK( + InstantiateFunctionForTest(else_fn.name(), library, &else_result)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto if_op = + ops::If(scope.WithOpName(op_name), less, + std::initializer_list{less, y, x}, {DT_INT32}, then_fn, + else_fn, ops::If::OutputShapes({PartialTensorShape()})); + auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // then body. 
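The fixture above combines two independent boolean parameters and derives readable test names from the resulting tuple. The same googletest pattern in isolation (type and test names here are illustrative only):

#include <string>
#include <tuple>

#include "gtest/gtest.h"

class TwoFlagTest : public ::testing::TestWithParam<std::tuple<bool, bool>> {};

TEST_P(TwoFlagTest, Runs) {
  bool flag_a = std::get<0>(GetParam());
  bool flag_b = std::get<1>(GetParam());
  // Placeholder body; a real test would branch on the two flags.
  EXPECT_TRUE(flag_a || !flag_a);
  (void)flag_b;
}

INSTANTIATE_TEST_SUITE_P(
    AllCombinations, TwoFlagTest,
    ::testing::Combine(::testing::Bool(), ::testing::Bool()),
    [](const ::testing::TestParamInfo<TwoFlagTest::ParamType>& info) {
      // Build a valid identifier from the two booleans, as the patch does.
      return std::string(std::get<0>(info.param) ? "a_on" : "a_off") +
             (std::get<1>(info.param) ? "_b_on" : "_b_off");
    });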
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); + auto cond = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity), 17); + auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); + auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), mul, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // else body. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); + auto cond_1 = ops::Const( + scope.WithOpName("cond_1").WithControlDependencies(identity), 23); + auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); + auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + +void ConditionalTestFixture::RunTest() { + Graph graph(OpRegistry::Global()); + if (wrap_condition_in_function_) { + // Wrap condition in a function which is called from `graph`. + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + + Graph cond_graph(OpRegistry::Global()); + BuildCondGraph(&cond_graph); + + FunctionDef cond_fdef; + TF_ASSERT_OK(GraphToFunctionDef(cond_graph, "cond_fn", &cond_fdef)); + + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = cond_fdef; + TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib)); + NodeDef cond_fn; + cond_fn.set_name("cond_node"); + cond_fn.set_op("cond_fn"); + *(cond_fn.add_input()) = "source"; + Status status; + scope.graph()->AddNode(cond_fn, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.ToGraph(&graph)); + } else { + // Build condition in `graph`. + BuildCondGraph(&graph); + } + FunctionLibraryDefinition library(graph.flib_def()); // If `restrict_to_tpu_nodes_` is true let filter function return true for // `_tpu_replicate` nodes. NodeFilter node_filter = @@ -116,99 +246,47 @@ void ConditionalTestFixture::RunTest() { ? 
[](const Node* n) { return n->attrs().Find("_tpu_replicate"); } : NodeFilter{}; - FunctionLibraryDefinition library(OpRegistry::Global(), {}); GraphDef optimized_graph_def; graph.ToGraphDef(&optimized_graph_def); - TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(&optimized_graph_def, - &library, node_filter)); - TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library, node_filter)); - GraphDef converted_graph_def; - graph.ToGraphDef(&converted_graph_def); + TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef( + &optimized_graph_def, &library, node_filter, + /*include_functions=*/wrap_condition_in_function_)); + TF_ASSERT_OK(FunctionalizeControlFlow( + &graph, &library, node_filter, + /*include_functions=*/wrap_condition_in_function_)); - for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { - string op_name; - NameAttrList then_fn; - NameAttrList else_fn; - TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); - InstantiationResultForTest else_result; - TF_EXPECT_OK( - InstantiateFunctionForTest(else_fn.name(), library, &else_result)); + if (wrap_condition_in_function_) { + // Check if function body was functionalized. + auto pflr = absl::make_unique( + /*device_mgr=*/nullptr, tensorflow::Env::Default(), + /*config=*/nullptr, TF_GRAPH_DEF_VERSION, &library, + tensorflow::OptimizerOptions()); + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + FunctionLibraryRuntime::Handle handle; - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = - ops::If(scope.WithOpName(op_name), less, - std::initializer_list{less, y, x}, {DT_INT32}, then_fn, - else_fn, ops::If::OutputShapes({PartialTensorShape()})); - auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // then body. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); - auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); - auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); - auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); - auto cond = ops::Const( - scope.WithOpName("cond").WithControlDependencies(identity), 17); - auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); - auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), mul, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(then_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), - result.arg_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // else body. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); - auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); - auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); - auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); - auto cond_1 = ops::Const( - scope.WithOpName("cond_1").WithControlDependencies(identity), 23); - auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); - auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(else_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), - result.arg_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Functionalized function name is the type string of `cond_node`. + string func_name; + for (Node* n : graph.nodes()) { + if (n->name() == "cond_node") { + func_name = n->type_string(); + break; + } } + TF_ASSERT_OK(flr->Instantiate(func_name, AttrSlice(), &handle)); + const FunctionBody* body = flr->GetFunctionBody(handle); + GraphDef graph_def; + body->graph->ToGraphDef(&graph_def); + CheckGraphDef(graph_def, library); + } else { + // Check if graphs were functionalized. + CheckGraphDef(optimized_graph_def, library); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + CheckGraphDef(converted_graph_def, library); } } -TEST_P(ConditionalTestFixture, ConditionalTests) { RunTest(); } - -INSTANTIATE_TEST_SUITE_P( - FunctionalizeControlFlow, ConditionalTestFixture, ::testing::Bool(), - [](const ::testing::TestParamInfo& - info) { return info.param ? "with_filter" : "without_filter"; }); - // Returns the names of the "cond" and "body" functions for the While node // in a graph. Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index dce5efe5557..79412c4abc8 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -24,11 +24,11 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" diff --git a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc index 1a26f974989..02f178f9acf 100644 --- a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc +++ b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor.h" @@ -139,5 +140,11 @@ TEST(FusedBatchnormReserveSpaceTest, Test) { test::ExpectClose(results[0], results[1], /*atol=*/1e-4); test::ExpectClose(results[2], results[3], /*atol=*/1e-4); } + +static bool Initialized = [] { + tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true; + return true; +}(); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 26051c98cb7..0edd918a92d 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -108,6 +108,7 @@ tf_kernel_library( "stack_ops.cc", "stateful_random_ops.cc", "stateless_random_ops.cc", + "stateless_random_ops_v2.cc", "strided_slice_op.cc", "tensor_array_ops.cc", "tensor_list_ops.cc", @@ -187,6 +188,7 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:stateful_random_ops_header", + "//tensorflow/core/kernels:stateless_random_ops_v2_header", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc index d7a8e67dd33..807c061b60f 100644 --- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -28,13 +29,26 @@ class BroadcastToOp : public XlaOpKernel { : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* context) override { - const TensorShape input_shape = context->InputShape(0); TensorShape output_shape; OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); + auto output_status_or = + BroadcastTo(context->Input(0), output_shape.dim_sizes()); + OP_REQUIRES_OK(context, output_status_or.status()); + auto output = output_status_or.ValueOrDie(); + std::vector dynamic_dims; + OP_REQUIRES_OK( + context, context->ResolveInputDynamismIntoPredVector(1, &dynamic_dims)); + for (int64 dim = 0; dim < dynamic_dims.size(); ++dim) { + if (dynamic_dims[dim]) { + output = xla::SetDimensionSize( + output, + xla::Reshape(xla::Slice(context->Input(1), {dim}, {dim + 1}, {1}), + {}), + dim); + } + } - auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes()); - OP_REQUIRES_OK(context, output.status()); - context->SetOutput(0, output.ValueOrDie()); + context->SetOutput(0, output); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc index 7ac38369eb4..ad94c1383f8 100644 --- a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc @@ -63,36 +63,27 @@ class DequantizeOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { DataType input_type = ctx->input_type(0); - double minrange, maxrange; - - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &minrange)); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(2, &maxrange)); - - float min_range = static_cast(minrange); - float max_range = static_cast(maxrange); - float full_range, half_range; + xla::XlaOp input = ctx->Input(0); + xla::XlaOp output = xla::ConvertElementType(input, xla::F32); + xla::XlaOp min_range = xla::ConvertElementType(ctx->Input(1), xla::F32); + xla::XlaOp max_range = xla::ConvertElementType(ctx->Input(2), xla::F32); + xla::XlaOp full_range; + xla::XlaOp half_range; if (input_type == DT_QINT8) { - full_range = get_fullrange(); - half_range = (full_range + 1.0f) / 2.0f; + full_range = ScalarLike(output, get_fullrange()); + half_range = + (full_range + ScalarLike(output, 1.0f)) / ScalarLike(output, 2.0f); } else { OP_REQUIRES(ctx, input_type == DT_QUINT8, errors::InvalidArgument( "Only support DT_QINT8 or DT_QUINT8, got ", input_type)); - full_range = get_fullrange(); - half_range = 0.0f; + full_range = ScalarLike(output, get_fullrange()); + half_range = ScalarLike(output, 0.0f); } - float scale_factor = (max_range - min_range) / full_range; + xla::XlaOp scale = (max_range - min_range) / full_range; - xla::XlaOp input = ctx->Input(0); - xla::XlaOp output; - - output = xla::ConvertElementType(input, xla::F32); - - auto scale = ScalarLike(output, scale_factor); - auto halfrange = ScalarLike(output, half_range); - output = xla::Add(xla::Mul(xla::Add(output, halfrange), scale), - ScalarLike(output, min_range)); + output = xla::Add(xla::Mul(xla::Add(output, half_range), scale), min_range); if (dtype_ == DT_BFLOAT16) { output = xla::ConvertElementType(output, xla::BF16); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc 
b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc index 19aa85f9d42..b4b18dd2b36 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc @@ -49,7 +49,8 @@ class GatherOp : public XlaOpKernel { bool indices_are_sorted_; }; -REGISTER_XLA_OP(Name("XlaGather"), GatherOp); +REGISTER_XLA_OP(Name("XlaGather").CompileTimeConstantInput("slice_sizes"), + GatherOp); class ScatterOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc b/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc index 46585a26769..71920372cde 100644 --- a/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc @@ -30,7 +30,8 @@ class XlaReplicaIdOp : public XlaOpKernel { }; void XlaReplicaIdOp::Compile(XlaOpKernelContext* ctx) { - ctx->SetOutput(0, xla::ReplicaId(ctx->builder())); + ctx->SetOutput( + 0, xla::ConvertElementType(xla::ReplicaId(ctx->builder()), xla::S32)); } REGISTER_XLA_OP(Name("XlaReplicaId"), XlaReplicaIdOp); diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index bf9a9150ea6..213045e428a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -19,8 +19,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -108,32 +110,73 @@ class ReshapeOp : public XlaOpKernel { VLOG(2) << "Reshape from " << input_shape.DebugString() << " to " << shape.DebugString() << ", unknown_index=" << unknown_index; + auto input_xla_shape = ctx->InputXlaShape(0); + if (input_xla_shape->is_static()) { + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), shape.dim_sizes())); + return; + } + // Handing dynamic reshapes if input contains a dynamic dimension. + std::vector output_dim_sizes; + std::vector dims_are_dynamic; + for (int64 i = 0; i < shape.dims(); ++i) { + output_dim_sizes.push_back( + xla::Reshape(xla::Slice(ctx->Input(1), {i}, {i + 1}, {1}), {})); + } + OP_REQUIRES_OK( + ctx, ctx->ResolveInputDynamismIntoPredVector(1, &dims_are_dynamic)); + if (unknown_index == -1) { + // No unknown index. + ctx->SetOutput(0, + xla::DynamicReshape(ctx->Input(0), output_dim_sizes, + shape.dim_sizes(), dims_are_dynamic)); + return; + } + auto common_factors = + xla::CommonFactors(input_shape.dim_sizes(), shape.dim_sizes()); - shape_input.clear(); - // Run get input again, this time with dynamic dimension represented as - // "-1" - ctx->set_dynamic_dimension_is_minus_one(true); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input)); - - int dynamic_dimension = -1; - - for (int d = 0; d < num_dims; ++d) { - const int32 size = shape_input[d]; - if (size == -1) { - if (dynamic_dimension == -1) { - dynamic_dimension = d; + // Find common_factors that the input belongs to. 
+ for (int64 i = 0; i < common_factors.size() - 1; ++i) { + auto start = common_factors[i]; + auto end = common_factors[i + 1]; + bool input_is_dynamic = false; + // product of all input dims in this group. E.g., in + // reshape(Tensor([2, 3, 3]), [3, -1, 3]) product of the group + // containing -1 will be 6. + xla::XlaOp product = xla::One(ctx->builder(), xla::S32); + for (int64 dim = start.first; dim < end.first; ++dim) { + if (input_xla_shape->is_dynamic_dimension(dim)) { + input_is_dynamic = true; + } + product = xla::Mul(product, xla::GetDimensionSize(ctx->Input(0), dim)); + } + bool unknown_dim_in_group = false; + // The real size for the -1 dimension in a reshape. E.g., in + // reshape(Tensor([2, 3, 3]), [3, -1, 3]) this will be 2. + xla::XlaOp unknown_dim_size = product; + for (int64 dim = start.second; dim < end.second; ++dim) { + if (dim == unknown_index) { + unknown_dim_in_group = true; + } else { - if (unknown_index != d) { - dynamic_dimension = d; - } + unknown_dim_size = xla::Div(unknown_dim_size, output_dim_sizes[dim]); + } + } - } - // Pass unknown_index to Xla::Reshape as a hint for dynamic shape inference - // in XLA to know which output dimension is dynamic. - ctx->SetOutput(0, xla::ReshapeWithInferredDimension( - ctx->Input(0), shape.dim_sizes(), dynamic_dimension)); + if (unknown_dim_in_group) { + // If input dim is dynamic, output dim at the -1 position must be + // dynamic. Similarly, if input dim is static, output dim has to be + // static at the -1 dimension. + dims_are_dynamic[unknown_index] = input_is_dynamic; + output_dim_sizes[unknown_index] = unknown_dim_size; + + ctx->SetOutput( + 0, xla::DynamicReshape(ctx->Input(0), output_dim_sizes, + shape.dim_sizes(), dims_are_dynamic)); + VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() + << " to " << xla::VectorString(shape.dim_sizes()) + << ", dynamic_dims=" << xla::VectorString(dims_are_dynamic); + return; + } + } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 97359f81eee..d63b8146491 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -74,12 +74,44 @@ class UnsortedSegmentReduce : public XlaOpKernel { " vs. ", indices_shape.dim_size(d))); } xla::XlaBuilder* builder = ctx->builder(); + // data shape = [indices_shape, segment_shape] + // buffer shape = [num_segment, segment_shape] + // We now create the buffer shape by reverse engineering data shape into + // indices shape and segment shape. TensorShape buffer_shape = data_shape; buffer_shape.RemoveDimRange(0, indices_shape.dims()); buffer_shape.InsertDim(0, num_segments); + auto buffer = xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); + // Build dynamic dim sizes for buffer, as well as whether each dimension + // size is dynamic or static. We build two parts: num_segment part and + // segment_shape part. + std::vector<xla::XlaOp> buffer_dims; + std::vector<bool> buffer_dims_are_dynamic; + // Build the "num_segment" part. + bool num_segments_is_dynamic; + OP_REQUIRES_OK( + ctx, ctx->ResolveInputDynamismIntoPred(2, &num_segments_is_dynamic)); + + buffer_dims.insert(buffer_dims.begin(), ctx->Input(2)); + buffer_dims_are_dynamic.insert(buffer_dims_are_dynamic.begin(), + num_segments_is_dynamic); + // Build the segment shape part.
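The comments in the Reshape change above work through reshape(Tensor([2, 3, 3]), [3, -1, 3]): xla::CommonFactors pairs input dims {2, 3} with output dims {3, -1}, so the group's element count is 6 and the inferred size of the -1 dimension is 6 / 3 = 2. The same per-group arithmetic as a small host-side sketch (the kernel performs it on XlaOps so it also covers dynamic sizes):

#include <cstdint>
#include <iostream>
#include <vector>

// Given the input dims of one common-factor group and the known output dims of
// that group, the -1 dimension is the quotient of the two products.
int64_t InferUnknownDimInGroup(
    const std::vector<int64_t>& group_in_dims,
    const std::vector<int64_t>& group_out_known_dims) {
  int64_t product = 1;
  for (int64_t d : group_in_dims) product *= d;           // group element count
  for (int64_t d : group_out_known_dims) product /= d;    // divide out known dims
  return product;
}

int main() {
  // reshape(Tensor([2, 3, 3]), [3, -1, 3]): the group containing -1 pairs
  // input dims {2, 3} with output dims {3, -1}.
  std::cout << InferUnknownDimInGroup({2, 3}, {3}) << "\n";  // prints 2
  return 0;
}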
+ for (int64 i = indices_shape.dims(); i < data_shape.dims(); ++i) { + buffer_dims.push_back(xla::GetDimensionSize(data, i)); + buffer_dims_are_dynamic.push_back( + ctx->InputXlaShape(0)->is_dynamic_dimension(i)); + } + + for (int64 i = 0; i < buffer_dims.size(); ++i) { + if (buffer_dims_are_dynamic[i]) { + // For each dynamic dimension, call set-dimension-size on it. + buffer = xla::SetDimensionSize(buffer, buffer_dims[i], i); + } + } + auto combiner = [this](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) { return Combine(a, b); }; diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 85917af6a65..75faa2eac81 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Shape Ops. +#include "absl/strings/str_format.h" #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -24,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -65,6 +67,47 @@ class ShapeOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp); +class XlaSetBoundOp : public XlaOpKernel { + public: + explicit XlaSetBoundOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape bound_shape = ctx->InputShape("bound"); + + OP_REQUIRES( + ctx, + ctx->InputType("bound") == DT_INT32 && + ctx->InputType("input") == DT_INT32, + errors::InvalidArgument( + "XlaSetBound can only set bound for int32 scalar value: got", + input_shape.DebugString())); + + OP_REQUIRES( + ctx, input_shape.dims() == 0, + errors::InvalidArgument("XlaSetBound should only be used to set a " + "bound to the an int32 scalar value: got", + input_shape.DebugString())); + + OP_REQUIRES( + ctx, bound_shape.dims() == 0, + errors::InvalidArgument("XlaSetBound should only be used to set a " + "bound to the an int32 scalar value: got", + bound_shape.DebugString())); + int64 bound; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound)); + + xla::XlaOp result = xla::CustomCall( + ctx->builder(), "SetBound", {ctx->Input("input")}, + ctx->InputXlaShape("input").ValueOrDie(), absl::StrFormat("%d", bound)); + ctx->SetOutput(0, result); + } +}; + +REGISTER_XLA_OP(Name("XlaSetBound").CompileTimeConstantInput("bound"), + XlaSetBoundOp); + class ShapeNOp : public XlaOpKernel { public: explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc index 46d4b70606e..a46cceddced 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -30,6 +30,7 @@ limitations under the License. 
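As the comments in the UnsortedSegmentReduce change above describe, the data shape is treated as [indices_shape, segment_shape] and the reduction buffer as [num_segments, segment_shape]. A host-side sketch of that shape derivation (the kernel builds the dynamic counterpart from GetDimensionSize and the num_segments operand):

#include <cstdint>
#include <iostream>
#include <vector>

// Drop the leading `indices_rank` dims of the data shape and prepend
// num_segments to obtain the buffer shape.
std::vector<int64_t> SegmentBufferShape(const std::vector<int64_t>& data_shape,
                                        int indices_rank,
                                        int64_t num_segments) {
  std::vector<int64_t> buffer(data_shape.begin() + indices_rank,
                              data_shape.end());
  buffer.insert(buffer.begin(), num_segments);
  return buffer;
}

int main() {
  // data [8, 16, 32] with 2-D indices [8, 16] and num_segments = 5 -> [5, 32].
  for (int64_t d : SegmentBufferShape({8, 16, 32}, /*indices_rank=*/2, 5)) {
    std::cout << d << " ";
  }
  std::cout << "\n";
  return 0;
}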
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/rng_alg.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/math/math_util.h" @@ -180,7 +181,7 @@ Status CompileImpl( } xla::Literal alg_literal; TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal)); - auto alg = alg_literal.Get({}); + Algorithm alg = Algorithm(alg_literal.Get({})); if (!(alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX)) { return errors::InvalidArgument("Unsupported algorithm id: ", alg); } @@ -407,5 +408,80 @@ REGISTER_XLA_OP(Name("StatefulUniformFullInt") {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}), StatefulUniformFullIntOp); +xla::XlaOp IncreaseCounter(Algorithm const& alg, xla::XlaOp counter, + xla::XlaOp delta) { + // Multiplying 256 to be consistent with the CPU/GPU kernels + delta = delta * ConstantR0WithType(delta.builder(), xla::U64, 256); + if (alg == RNG_ALG_PHILOX) { + return xla::PhiloxIncreaseCounter(counter, delta); + } else { + return counter + delta; + } +} + +xla::XlaOp PadRight(xla::XlaOp a, int n) { + return xla::Pad(a, xla::ScalarLike(a, 0), + xla::MakeEdgePaddingConfig({{0, n}})); +} + +template +class RngSkipOp : public XlaOpKernel { + public: + explicit RngSkipOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const int state_input_idx = 0; + const int alg_input_idx = 1; + const int delta_input_idx = 2; + xla::XlaOp var; + TensorShape var_shape; + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(state_input_idx, STATE_ELEMENT_DTYPE, + &var_shape, &var)); + xla::Literal alg_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(alg_input_idx, &alg_literal)); + Algorithm alg = Algorithm(alg_literal.Get({})); + OP_REQUIRES(ctx, alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX, + errors::InvalidArgument("Unsupported algorithm id: ", alg)); + OP_REQUIRES_OK(ctx, CheckStateShape(alg, var_shape)); + if (read_old_value) { + auto counter_size = GetCounterSize(alg); + xla::XlaOp output = var; + if (RNG_MAX_COUNTER_SIZE > counter_size) { + // Because the size of `var` depends on the algorithm while we want the + // output to have a fixed size (to help shape inference), we fix the + // output size to be the maximal state size among algorithms, and right- + // pad it with zeros if var's size is smaller than that. 
+ output = PadRight(output, RNG_MAX_COUNTER_SIZE - counter_size); + } + ctx->SetOutput(0, output); + } + xla::XlaOp counter; + xla::XlaOp key; + std::tie(counter, key) = StateAndKeyFromVariable(alg, var); + xla::XlaOp delta = ctx->Input(delta_input_idx); + delta = BitcastConvertType(delta, xla::U64); + auto new_counter = IncreaseCounter(alg, counter, delta); + var = StateAndKeyToVariable(alg, new_counter, key); + xla::PrimitiveType state_element_type; + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type)); + var = BitcastConvertType(var, state_element_type); + OP_REQUIRES_OK( + ctx, ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RngSkipOp); +}; + +REGISTER_XLA_OP(Name("RngSkip").CompileTimeConstantInput("algorithm"), + RngSkipOp<>); + +using RngReadAndSkipOp = RngSkipOp; + +REGISTER_XLA_OP(Name("RngReadAndSkip").CompileTimeConstantInput("alg"), + RngReadAndSkipOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 13c3dbe489e..e606812bc4e 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -111,6 +111,8 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string, } } +namespace { + xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string, xla::XlaOp seeds, const xla::Shape& shape) { @@ -140,8 +142,6 @@ xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string, } } -namespace { - class StatelessRandomUniformOp : public XlaOpKernel { public: explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc new file mode 100644 index 00000000000..e46fec3c576 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc @@ -0,0 +1,485 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/stateless_random_ops_v2.h" + +#include + +#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" +#include "tensorflow/compiler/tf2xla/lib/random.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/rng_alg.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/math/math_util.h" + +namespace tensorflow { + +namespace { + +inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) { + if (alg == RNG_ALG_PHILOX) { + return xla::RandomAlgorithm::RNG_PHILOX; + } + return xla::RandomAlgorithm::RNG_THREE_FRY; +} + +inline Algorithm RandomAlgorithmToAlgorithm(xla::RandomAlgorithm const& alg) { + if (alg == xla::RandomAlgorithm::RNG_PHILOX) { + return RNG_ALG_PHILOX; + } + return RNG_ALG_THREEFRY; +} + +xla::XlaOp GetCounter(xla::RandomAlgorithm const& alg, xla::XlaOp state) { + Algorithm alg_ = RandomAlgorithmToAlgorithm(alg); + return xla::Slice(state, {RNG_KEY_SIZE}, + {RNG_KEY_SIZE + GetCounterSize(alg_)}, {1}); +} + +xla::RngOutput BitGenerator(xla::RandomAlgorithm const& alg, xla::XlaOp key, + xla::XlaOp counter, const xla::Shape& shape) { + key = BitcastConvertType(key, xla::U64); + counter = BitcastConvertType(counter, xla::U64); + xla::XlaOp state = xla::ConcatInDim(key.builder(), {key, counter}, 0); + xla::XlaOp result = xla::RngBitGenerator(alg, state, shape); + auto new_counter = GetCounter(alg, xla::GetTupleElement(result, 0)); + new_counter = BitcastConvertType(new_counter, xla::S64); + return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1), + /*state=*/new_counter}; +} + +std::tuple GetKeyCounterAlg( + absl::string_view device_type_string, xla::XlaOp key) { + // The Philox algorithm may cause performance regression on other devices. + // Turn on the Philox algorithm for the CPU and GPU backends only. 
+ if (device_type_string == DEVICE_GPU_XLA_JIT || + device_type_string == DEVICE_CPU_XLA_JIT) { + auto counter_key = xla::ScramblePhiloxKey(key); + return std::make_tuple(counter_key.second, counter_key.first, + RNG_ALG_PHILOX); + } else { + auto counter_shape = + xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE}); + auto counter = xla::Zeros(key.builder(), counter_shape); + return std::make_tuple(key, counter, RNG_ALG_THREEFRY); + } +} + +} // namespace + +xla::RngOutput StatelessRngUniformV2(xla::RandomAlgorithm const& alg, + xla::XlaOp key, xla::XlaOp counter, + const xla::Shape& shape, xla::XlaOp minval, + xla::XlaOp maxval) { + xla::XlaBuilder* builder = key.builder(); + xla::PrimitiveType type = shape.element_type(); + using std::placeholders::_1; + using std::placeholders::_2; + using std::placeholders::_3; + auto generator = std::bind(BitGenerator, alg, _1, _2, _3); + switch (type) { + case xla::F32: + case xla::F64: + return xla::UniformFloatingPointDistribution(key, counter, generator, + minval, maxval, shape); + case xla::S32: + case xla::S64: + case xla::U32: + case xla::U64: + return UniformIntDistribution(key, counter, generator, minval, maxval, + shape); + break; + default: + return {builder->ReportError(xla::Unimplemented( + "Types other than F32, S32, S64, U32 and U64 are not " + "implemented by " + "StatelessRngUniformV2; got %s", + xla::primitive_util::LowercasePrimitiveTypeName(type))), + counter}; + } +} + +namespace { + +xla::RngOutput StatelessRngUniformFullInt(xla::RandomAlgorithm const& alg, + xla::XlaOp key, xla::XlaOp counter, + const xla::Shape& shape) { + xla::XlaBuilder* builder = key.builder(); + + xla::PrimitiveType type = shape.element_type(); + xla::RngOutput output = BitGenerator(alg, key, counter, shape); + switch (type) { + case xla::U32: + case xla::U64: + return output; + case xla::S32: + case xla::S64: + return xla::RngOutput{BitcastConvertType(output.value, type), + output.state}; + default: + return { + builder->ReportError(xla::Unimplemented( + "Types other than U32, S32, U64 and S64 are not implemented by " + "StatelessRngUniformFullInt; got: %s", + xla::primitive_util::LowercasePrimitiveTypeName(type))), + output.state}; + } +} + +Status GetAlgorithm(XlaOpKernelContext* ctx, int alg_input_idx, + xla::RandomAlgorithm* alg) { + auto alg_shape = ctx->InputShape(alg_input_idx); + if (alg_shape.dims() != 0) { + return errors::InvalidArgument("algorithm must be of shape [], not ", + alg_shape.DebugString()); + } + xla::Literal alg_literal; + TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal)); + auto alg_ = Algorithm(alg_literal.Get({})); + *alg = AlgorithmToRandomAlgorithm(alg_); + return Status::OK(); +} + +xla::XlaOp MaybeSliceCounter(xla::RandomAlgorithm const& alg, + TensorShape const& counter_shape, + xla::XlaOp counter) { + auto input_counter_size = counter_shape.dim_size(0); + auto real_counter_size = GetCounterSize(RandomAlgorithmToAlgorithm(alg)); + if (input_counter_size > real_counter_size) { + counter = xla::Slice(counter, {0}, {real_counter_size}, {1}); + } + return counter; +} + +class StatelessRandomUniformOp : public XlaOpKernel { + public: + explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* builder = ctx->builder(); + + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + const int key_input_idx = 1; + const int 
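GetKeyCounterAlg above picks the RNG per backend: Philox (with a scrambled key) on the XLA CPU and GPU JIT devices, ThreeFry with a zero counter elsewhere. A rough stand-alone sketch of just the selection logic; the literal device strings are assumptions standing in for DEVICE_CPU_XLA_JIT / DEVICE_GPU_XLA_JIT:

#include <iostream>
#include <string>

enum class Rng { kThreeFry, kPhilox };

// Philox is only enabled where it does not regress performance; other
// backends (e.g. TPU) fall back to ThreeFry.
Rng ChooseRng(const std::string& device_type) {
  return (device_type == "XLA_CPU_JIT" || device_type == "XLA_GPU_JIT")
             ? Rng::kPhilox
             : Rng::kThreeFry;
}

int main() {
  std::cout << (ChooseRng("XLA_GPU_JIT") == Rng::kPhilox) << "\n";    // 1
  std::cout << (ChooseRng("XLA_TPU_JIT") == Rng::kThreeFry) << "\n";  // 1
  return 0;
}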
counter_input_idx = 2; + const int alg_input_idx = 3; + xla::XlaOp key = ctx->Input(key_input_idx); + xla::XlaOp counter = ctx->Input(counter_input_idx); + + xla::RandomAlgorithm alg; + OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg)); + + auto counter_shape = ctx->InputShape(counter_input_idx); + OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg), + ctx->InputShape(key_input_idx), + counter_shape)); + + DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT; + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); + xla::PrimitiveType rng_primitive_type = xla_shape.element_type(); + + counter = MaybeSliceCounter(alg, counter_shape, counter); + + auto result = StatelessRngUniformV2( + alg, key, counter, xla_shape, + xla::ConstantR0WithType(builder, rng_primitive_type, 0.0), + xla::ConstantR0WithType(builder, rng_primitive_type, 1.0)); + auto uniform = MaybeConvertF32ToBF16(result.value, dtype_); + ctx->SetOutput(0, uniform); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp); +}; + +REGISTER_XLA_OP(Name("StatelessRandomUniformV2") + .CompileTimeConstantInput("shape") + .CompileTimeConstantInput("alg") + .TypeConstraint("dtype", + {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}), + StatelessRandomUniformOp); + +class StatelessRandomUniformIntOp : public XlaOpKernel { + public: + explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + const int key_input_idx = 1; + const int counter_input_idx = 2; + const int alg_input_idx = 3; + xla::XlaOp key = ctx->Input(key_input_idx); + xla::XlaOp counter = ctx->Input(counter_input_idx); + + xla::RandomAlgorithm alg; + OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg)); + + auto counter_shape = ctx->InputShape(counter_input_idx); + OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg), + ctx->InputShape(key_input_idx), + counter_shape)); + + const int minval_input_idx = 4; + const int maxval_input_idx = 5; + TensorShape minval_shape = ctx->InputShape(minval_input_idx); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape), + errors::InvalidArgument("minval must be scalar, got shape ", + minval_shape.DebugString())); + TensorShape maxval_shape = ctx->InputShape(maxval_input_idx); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape), + errors::InvalidArgument("maxval must be scalar, got shape ", + maxval_shape.DebugString())); + + xla::XlaOp minval = ctx->Input(minval_input_idx); + xla::XlaOp maxval = ctx->Input(maxval_input_idx); + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); + + counter = MaybeSliceCounter(alg, counter_shape, counter); + auto result = + StatelessRngUniformV2(alg, key, counter, xla_shape, minval, maxval); + ctx->SetOutput(0, result.value); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp); +}; + +REGISTER_XLA_OP(Name("StatelessRandomUniformIntV2") + .CompileTimeConstantInput("shape") + .CompileTimeConstantInput("alg") + .TypeConstraint("dtype", + {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}), + StatelessRandomUniformIntOp); + +class StatelessRandomUniformFullIntOp : public XlaOpKernel { + public: + explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx) 
+ : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + const int key_input_idx = 1; + const int counter_input_idx = 2; + const int alg_input_idx = 3; + xla::XlaOp key = ctx->Input(key_input_idx); + xla::XlaOp counter = ctx->Input(counter_input_idx); + + xla::RandomAlgorithm alg; + OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg)); + + auto counter_shape = ctx->InputShape(counter_input_idx); + OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg), + ctx->InputShape(key_input_idx), + counter_shape)); + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); + + counter = MaybeSliceCounter(alg, counter_shape, counter); + auto result = StatelessRngUniformFullInt(alg, key, counter, xla_shape); + ctx->SetOutput(0, result.value); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp); +}; + +REGISTER_XLA_OP(Name("StatelessRandomUniformFullIntV2") + .CompileTimeConstantInput("shape") + .CompileTimeConstantInput("alg") + .TypeConstraint("dtype", + {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}), + StatelessRandomUniformFullIntOp); + +class StatelessRandomNormalOp : public XlaOpKernel { + public: + explicit StatelessRandomNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + const int key_input_idx = 1; + const int counter_input_idx = 2; + const int alg_input_idx = 3; + xla::XlaOp key = ctx->Input(key_input_idx); + xla::XlaOp counter = ctx->Input(counter_input_idx); + + xla::RandomAlgorithm alg; + OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg)); + + auto counter_shape = ctx->InputShape(counter_input_idx); + OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg), + ctx->InputShape(key_input_idx), + counter_shape)); + + DataType rng_dtype = dtype_ == DT_DOUBLE ? 
DT_DOUBLE : DT_FLOAT; + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); + + using std::placeholders::_1; + using std::placeholders::_2; + using std::placeholders::_3; + auto generator = std::bind(BitGenerator, alg, _1, _2, _3); + counter = MaybeSliceCounter(alg, counter_shape, counter); + auto result = xla::NormalFloatingPointDistribution(key, counter, generator, + xla_shape); + auto normal = MaybeConvertF32ToBF16(result.value, dtype_); + ctx->SetOutput(0, normal); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp); +}; + +REGISTER_XLA_OP(Name("StatelessRandomNormalV2") + .CompileTimeConstantInput("shape") + .CompileTimeConstantInput("alg") + .TypeConstraint("dtype", + {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}), + StatelessRandomNormalOp); + +class StatelessTruncatedNormalOp : public XlaOpKernel { + public: + explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + const int key_input_idx = 1; + const int counter_input_idx = 2; + const int alg_input_idx = 3; + xla::XlaOp key = ctx->Input(key_input_idx); + xla::XlaOp counter = ctx->Input(counter_input_idx); + + xla::RandomAlgorithm alg; + OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg)); + + auto counter_shape = ctx->InputShape(counter_input_idx); + OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg), + ctx->InputShape(key_input_idx), + counter_shape)); + + xla::XlaBuilder* builder = ctx->builder(); + + DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT; + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); + + counter = MaybeSliceCounter(alg, counter_shape, counter); + auto result = StatelessRngUniformV2( + alg, key, counter, xla_shape, + xla::MinPositiveNormalValue(builder, xla_shape.element_type()), + xla::One(builder, xla_shape.element_type())); + xla::XlaOp truncated_normal = TruncatedNormal(result.value); + truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_); + ctx->SetOutput(0, truncated_normal); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp); +}; + +REGISTER_XLA_OP(Name("StatelessTruncatedNormalV2") + .CompileTimeConstantInput("shape") + .CompileTimeConstantInput("alg") + .TypeConstraint("dtype", + {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}), + StatelessTruncatedNormalOp); + +class GetKeyCounterAlgOp : public XlaOpKernel { + public: + explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx), + device_type_string_(ctx->device_type().type_string()) {} + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape seed_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, seed_shape == TensorShape({2}), + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + xla::XlaOp seed = ctx->Input(0); + + xla::XlaBuilder* builder = seed.builder(); + xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + xla::XlaOp key = ConvertElementType(seed0, xla::U64) | + ShiftLeft(ConvertElementType(seed1, xla::U64), + ConstantR0WithType(builder, xla::U64, 32)); + auto key_counter_alg = GetKeyCounterAlg(device_type_string_, key); + key = std::get<0>(key_counter_alg); 
+ auto counter = std::get<1>(key_counter_alg); + auto alg = std::get<2>(key_counter_alg); + key = xla::Reshape(key, {RNG_KEY_SIZE}); + ctx->SetOutput(0, key); + ctx->SetOutput(1, counter); + ctx->SetOutput(2, ConstantR0(builder, static_cast(alg))); + } + + private: + string device_type_string_; + + TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterAlgOp); +}; + +REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 784b790767c..943d92982cb 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/core/util/strided_slice_op.h" +#include + +#include "absl/algorithm/container.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -23,16 +26,20 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/ops_util.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/mem.h" namespace tensorflow { namespace { +using errors::InvalidArgument; class StridedSliceOp : public XlaOpKernel { public: @@ -48,7 +55,6 @@ class StridedSliceOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); const TensorShape begin_shape = ctx->InputShape("begin"); - OP_REQUIRES( ctx, begin_shape.dims() == 1, errors::InvalidArgument("'begin' input has to be a rank 1 vector")); @@ -78,20 +84,24 @@ class StridedSliceOp : public XlaOpKernel { TensorShape final_shape; PartialTensorShape dummy_processing_shape, partial_final_shape; bool dummy = false; - OP_REQUIRES_OK(ctx, ValidateStridedSliceOp( - begin_is_constant ? &begin_tensor : nullptr, - end_is_constant ? &end_tensor : nullptr, - strides_tensor, input_shape, begin_mask_, end_mask_, - ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, - &dummy_processing_shape, &partial_final_shape, - &dummy, &dummy, &dummy, &begin, &end, &strides)); + absl::InlinedVector output_to_sparse_mapping; + absl::InlinedVector output_to_processing_mapping; + OP_REQUIRES_OK( + ctx, + ValidateStridedSliceOp( + begin_is_constant ? &begin_tensor : nullptr, + end_is_constant ? 
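GetKeyCounterAlgOp above packs the two 32-bit words of the seed into a single 64-bit key, with seed1 occupying the high half. A host-side equivalent (shown for non-negative seeds; the kernel does the widening with ConvertElementType and ShiftLeft):

#include <cstdint>
#include <iostream>

// seed0 fills the low 32 bits of the key, seed1 the high 32 bits.
uint64_t PackKey(uint32_t seed0, uint32_t seed1) {
  return static_cast<uint64_t>(seed0) | (static_cast<uint64_t>(seed1) << 32);
}

int main() {
  std::cout << std::hex << PackKey(0xdeadbeef, 0x12345678) << "\n";
  // prints 12345678deadbeef
  return 0;
}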
&end_tensor : nullptr, strides_tensor, + input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, + shrink_axis_mask_, &dummy_processing_shape, &partial_final_shape, + &dummy, &dummy, &dummy, &begin, &end, &strides, + &output_to_sparse_mapping, &output_to_processing_mapping)); - OP_REQUIRES(ctx, partial_final_shape.AsTensorShape(&final_shape), - errors::InvalidArgument( - "XLA can't deduce compile time constant output " - "shape for strided slice: ", - partial_final_shape.DebugString(), - ", output shape must be a compile-time constant")); + OP_REQUIRES( + ctx, partial_final_shape.AsTensorShape(&final_shape), + InvalidArgument("XLA can't deduce compile time constant output " + "shape for strided slice: ", + partial_final_shape.DebugString(), + ", output shape must be a compile-time constant")); xla::XlaOp slice = ctx->Input(0); if (begin_is_constant && end_is_constant) { @@ -119,69 +129,84 @@ class StridedSliceOp : public XlaOpKernel { auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0)); OP_REQUIRES_OK(ctx, operand_shape_or.status()); xla::Shape xla_shape = operand_shape_or.ValueOrDie(); - if (xla_shape.is_static()) { - // Static output shape, return a static slice. - slice = xla::Reshape(slice, final_shape.dim_sizes()); + std::vector begins_are_dynamic; + OP_REQUIRES_OK( + ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic)); + std::vector ends_are_dynamic; + OP_REQUIRES_OK( + ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic)); + bool begins_are_static = absl::c_all_of( + begins_are_dynamic, [](bool dynamic) { return !dynamic; }); + OP_REQUIRES(ctx, begins_are_static, + errors::InvalidArgument( + "XLA can't use dynamic begin values for slice.")); + bool ends_are_static = absl::c_all_of( + ends_are_dynamic, [](bool dynamic) { return !dynamic; }); + // Static output shape, return a static slice. + slice = xla::Reshape(slice, final_shape.dim_sizes()); + if (xla_shape.is_static() && ends_are_static) { ctx->SetOutput(0, slice); return; } - auto input_dim_sizes = input_shape.dim_sizes(); - for (int64 i = 0; i < xla_shape.rank(); ++i) { - if (xla_shape.is_dynamic_dimension(i)) { - input_dim_sizes[i] = -1; + for (int64 i = 0; i < final_shape.dims(); ++i) { + int64 input_index = output_to_processing_mapping[i]; + if (input_index == -1) { + continue; } - } - PartialTensorShape input_partial_shape(input_dim_sizes); - partial_final_shape.Clear(); - end.clear(); - strides.clear(); - begin.clear(); - // Run shape inferenference again with partial shape. - OP_REQUIRES_OK(ctx, ValidateStridedSliceOp( - &begin_tensor, &end_tensor, strides_tensor, - input_partial_shape, begin_mask_, end_mask_, - ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, - &dummy_processing_shape, &partial_final_shape, - &dummy, &dummy, &dummy, &begin, &end, &strides)); - if (partial_final_shape.AsTensorShape(&final_shape)) { - // Static output shape, return a static slice. - slice = xla::Reshape(slice, final_shape.dim_sizes()); - ctx->SetOutput(0, slice); - return; - } + bool input_is_dynamic = xla_shape.is_dynamic_dimension(input_index); - // We consider slicing a dynamic tensor t with negative indices as a - // dynamic sized slice. E.g., t[: -n], the result length is shape(t) - n - for (int64 i = 0; i < partial_final_shape.dims(); ++i) { - bool dynamic_dim = partial_final_shape.dim_size(i) - 1; - bool backward_slice = end[i] < 0; - if (dynamic_dim && backward_slice) { + int64 sparse_index = output_to_sparse_mapping[i]; + bool end_is_dynamic = + sparse_index == -1 ? 
false : ends_are_dynamic[sparse_index]; + bool backward_slice = sparse_index == -1 + ? false + : end_literal.Get({sparse_index}) < 0; + if ((input_is_dynamic && backward_slice) || end_is_dynamic) { OP_REQUIRES( - ctx, strides[i] == 1, + ctx, strides[input_index] == 1, errors::InvalidArgument("XLA has not implemented dynamic " "sized slice with non-trival stride yet. " "Please file a bug against XLA")); - - OP_REQUIRES(ctx, begin[i] >= 0, - errors::InvalidArgument( - "XLA has not implemented dynamic " - "sized slice with negative begin index %lld. " - "Please file a bug against XLA", - begin[i])); // If there is a dynamic dimension, properly set dimension size of // the result. - auto operand_size = xla::GetDimensionSize(ctx->Input(0), i); - - operand_size = xla::Add( - operand_size, xla::ConstantR0(ctx->builder(), end[i])); + auto operand_size = xla::GetDimensionSize(ctx->Input(0), input_index); + if (backward_slice) { + // We consider slicing a dynamic tensor t with negative indices as + // a dynamic sized slice. E.g., t[: -n], the result length is + // shape(t) - n. + OP_REQUIRES(ctx, !end_is_dynamic, + errors::InvalidArgument( + "XLA has not implemented dynamic " + "sized slice with dynamic negative index %lld. ")); + operand_size = xla::Add( + operand_size, + xla::ConstantR0(ctx->builder(), + end_literal.Get({sparse_index}))); + } else { + // The end of slice with dynamic slice size is the min of operand + // shape and slice size. E.g., t[:end_size], result size is + // min(shape(t), end_size). + xla::XlaOp end_size; + if (end_is_dynamic) { + end_size = xla::Reshape(xla::Slice(ctx->Input(2), {sparse_index}, + {sparse_index + 1}, {1}), + {}); + } else { + end_size = + xla::ConstantR0(ctx->builder(), end[input_index]); + } + operand_size = xla::Min(operand_size, end_size); + } slice = xla::SetDimensionSize( slice, - xla::Sub(operand_size, - xla::ConstantR0(ctx->builder(), begin[i])), + xla::Sub(operand_size, xla::ConstantR0( + ctx->builder(), begin[input_index])), i); } } + ctx->SetOutput(0, slice); + return; } else { // When output shape is fully defined, it must be a size one slice: // @@ -239,9 +264,9 @@ class StridedSliceOp : public XlaOpKernel { std::vector output_shape_dim_sizes; slice = xla::DynamicSlice(slice, start_indices, slice_sizes); + slice = xla::Reshape(slice, final_shape.dim_sizes()); + ctx->SetOutput(0, slice); } - slice = xla::Reshape(slice, final_shape.dim_sizes()); - ctx->SetOutput(0, slice); } private: @@ -267,6 +292,83 @@ class StridedSliceGradOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_)); } + // When the begin / end is unknown, compile the gradient into dynamic update + // slice into a broadcasted 0s. + // + // Broadcasted 0 + // +----------------------+ + // | +----+ | + // |<-begin->|grad|<-end->| <== Dynamic update grad into 0s. 
+ // | +----+ | + // +----------------------+ + void CompileAsDynamicUpdateSlice(XlaOpKernelContext* ctx, + const TensorShape& input_shape, + const xla::Literal& strides_literal) { + bool dummy = false; + Tensor strides_tensor; + PartialTensorShape processing_shape, final_shape; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; + + absl::InlinedVector output_to_sparse_mapping; + absl::InlinedVector output_to_processing_mapping; + + OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, + &strides_tensor)); + OP_REQUIRES_OK( + ctx, ValidateStridedSliceOp( + nullptr, nullptr, strides_tensor, input_shape, begin_mask_, + end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, + &processing_shape, &final_shape, &dummy, &dummy, &dummy, + &begin, &end, &strides, &output_to_sparse_mapping, + &output_to_processing_mapping)); + for (int64 i = 0; i < processing_shape.dims(); ++i) { + OP_REQUIRES( + ctx, strides[i] == 1, + errors::InvalidArgument("Strides in strided slice grad have to be " + "one when inputs are not constant.")); + } + + auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0)); + zero = xla::Broadcast(zero, input_shape.dim_sizes()); + xla::XlaOp grad = ctx->Input(4); + xla::Shape grad_shape = ctx->InputXlaShape(4).ValueOrDie(); + // Undo any new/shrink axes. + VLOG(1) << "xla grad shape" << grad_shape; + VLOG(1) << "input_shape" << input_shape.DebugString(); + std::vector begins(processing_shape.dims(), + xla::Zero(ctx->builder(), xla::S32)); + for (int64 i = 0; i < grad_shape.rank(); ++i) { + // Use grad shape, which is known, to update unknown processing shape. + // Grad shape is the output of the ValidateStridedSliceOp function in + // forward pass, thus we use output_to_processing_mapping. + if (output_to_processing_mapping[i] != -1) { + processing_shape.set_dim(output_to_processing_mapping[i], + grad_shape.dimensions(i)); + } + + // Similarly, use output_to_sparse_mapping to find out corresponding + // begin dim of the output, as indices for dynamic update slice. 
+ int64 begin_dim = output_to_sparse_mapping[i]; + if (begin_dim != -1) { + auto begin_index = + xla::Slice(ctx->Input(1), {begin_dim}, {begin_dim + 1}, {1}); + auto begin_index_scalar = xla::Reshape( + xla::ShapeUtil::MakeScalarShape(xla::S32), begin_index); + begins[output_to_sparse_mapping[i]] = begin_index_scalar; + } + } + VLOG(1) << "processing_shape" << processing_shape.DebugString(); + TensorShape full_processing_shape; + OP_REQUIRES(ctx, processing_shape.AsTensorShape(&full_processing_shape), + errors::InvalidArgument( + "Processing shape ", processing_shape.DebugString(), + " can't be fully inferred from grad shape")); + grad = xla::Reshape(grad, full_processing_shape.dim_sizes()); + grad = xla::DynamicUpdateSlice(zero, grad, begins); + ctx->SetOutput(0, grad); + } void Compile(XlaOpKernelContext* ctx) override { TensorShape processing_shape, final_shape; absl::InlinedVector begin; @@ -275,12 +377,15 @@ class StridedSliceGradOp : public XlaOpKernel { TensorShape input_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); - xla::Literal begin_literal, end_literal, strides_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); - OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal)); - OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal)); + bool begin_is_constant = ctx->ConstantInput(1, &begin_literal).ok(); + bool end_is_constant = ctx->ConstantInput(2, &end_literal).ok(); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal)); + if (!(begin_is_constant && end_is_constant)) { + CompileAsDynamicUpdateSlice(ctx, input_shape, strides_literal); + return; + } Tensor begin_tensor, end_tensor, strides_tensor; OP_REQUIRES_OK( ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor)); @@ -423,7 +528,12 @@ class StridedSliceAssignOp : public XlaOpKernel { TensorShape lhs_shape; xla::XlaOp lhs; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs)); + if (ctx->input_type(0) == DT_RESOURCE) { + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs)); + } else { + lhs_shape = ctx->InputShape(0); + lhs = ctx->Input(0); + } const TensorShape rhs_shape = ctx->InputShape(4); @@ -481,7 +591,11 @@ class StridedSliceAssignOp : public XlaOpKernel { lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); + if (ctx->input_type(0) == DT_RESOURCE) { + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); + } else { + ctx->SetOutput(0, lhs); + } } private: @@ -497,5 +611,11 @@ REGISTER_XLA_OP(Name("ResourceStridedSliceAssign") .CompileTimeConstantInput("strides"), StridedSliceAssignOp); +REGISTER_XLA_OP(Name("TensorStridedSliceUpdate") + .CompileTimeConstantInput("begin") + .CompileTimeConstantInput("end") + .CompileTimeConstantInput("strides"), + StridedSliceAssignOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 976ff91f6ce..1ea0e797675 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -45,22 +45,32 @@ namespace tensorflow { namespace { // GetTensorListDynamicDims collects the dynamic dimensions that a tensorlist -// may carry and returns them in a 2D vector: int64[ElementSize][DimSize]. If a -// dimension is static, a constant dimension is returned. +// may carry and returns them in a 2D vector: XlaOp[ElementSize][DimSize]. 
If a +// dimension is static, a constant dimension is returned. If a dim is dynamic, a +// dynamic XlaOp representing the dynamic size is returned. xla::StatusOr>> GetTensorListDynamicDims( XlaOpKernelContext* ctx, const xla::Shape& element_shape, const xla::Shape& list_shape, int64 num_elements) { std::vector dynamic_sizes; - ctx->set_dynamic_dimension_is_minus_one(true); // The multiplier can be a dynamic value. TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntVector(0, &dynamic_sizes)); + std::vector dims_are_dynamic; + TF_RETURN_IF_ERROR( + ctx->ResolveInputDynamismIntoPredVector(0, &dims_are_dynamic)); + bool leading_dim_is_dynamic; + TF_RETURN_IF_ERROR( + ctx->ResolveInputDynamismIntoPred(1, &leading_dim_is_dynamic)); std::vector> list_dynamic_dims; // Set dynamic dimension size to 0 for initialization value. std::vector dynamic_dims; - // Leading dim is a static dimension. - dynamic_dims.push_back(xla::ConstantR0(ctx->builder(), num_elements)); + if (leading_dim_is_dynamic) { + dynamic_dims.push_back(ctx->Input(1)); + } else { + dynamic_dims.push_back( + xla::ConstantR0(ctx->builder(), num_elements)); + } for (int64 dim = 0; dim < element_shape.dimensions_size(); ++dim) { - if (ctx->is_dynamic_dimension(dynamic_sizes[dim])) { + if (dims_are_dynamic[dim]) { auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1}); dynamic_dim_size = xla::Reshape(dynamic_dim_size, {}); dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32); @@ -80,11 +90,12 @@ class TensorListLengthOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { int64 leading_dim; - OP_REQUIRES_OK(ctx, - GetLeadingDimForTensorList(ctx->Input(0), &leading_dim)); - Tensor length_tensor(DT_INT32, {}); - length_tensor.scalar()() = static_cast(leading_dim); - ctx->SetConstantOutput(0, length_tensor); + xla::XlaOp leading_dim_size; + bool leading_dim_is_dynamic; + OP_REQUIRES_OK(ctx, GetLeadingDimForTensorList(ctx->Input(0), &leading_dim, + &leading_dim_is_dynamic, + &leading_dim_size)); + ctx->SetOutput(0, leading_dim_size); } private: @@ -134,6 +145,9 @@ class TensorListReserveOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { int64 num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); + bool num_element_is_dynamic; + OP_REQUIRES_OK( + ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic)); OP_REQUIRES( ctx, num_elements >= 0, errors::InvalidArgument( @@ -156,7 +170,8 @@ class TensorListReserveOp : public XlaOpKernel { if (got_shape) { xla::Shape list_shape; OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape( - element_shape, num_elements, &list_shape)); + element_shape, num_elements, + num_element_is_dynamic, &list_shape)); // Set up dynamic dimension sizes to create the zero tensor. 
auto list_dynamic_dims_or = GetTensorListDynamicDims( ctx, element_shape, list_shape, num_elements); @@ -175,8 +190,8 @@ class TensorListReserveOp : public XlaOpKernel { return; } - xla::XlaOp result = - BuildUninitializedTensorList(ctx->builder(), num_elements); + xla::XlaOp result = BuildUninitializedTensorList( + ctx->builder(), num_elements, num_element_is_dynamic, ctx->Input(1)); ctx->SetTensorListOutput(0, result); } @@ -200,6 +215,9 @@ class EmptyTensorListOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { int64 max_num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements)); + bool num_element_is_dynamic; + OP_REQUIRES_OK( + ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic)); OP_REQUIRES(ctx, max_num_elements >= 0, errors::InvalidArgument( "XLA compilation requires a fixed tensor list size. Set " @@ -210,9 +228,9 @@ class EmptyTensorListOp : public XlaOpKernel { if (dtype_ != DT_VARIANT) { // We are creating a non-nested TensorList. - // If element shape is compile time constant and it's not "unknown rank" - // shape (-1), create an initialized TensorList. Otherwise create an - // uninitialized TensorList. + // If element shape is compile time constant and it's not "unknown + // rank" shape (-1), create an initialized TensorList. Otherwise + // create an uninitialized TensorList. xla::XlaOp element_shape_handle = ctx->Input(0); xla::PrimitiveType type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type)); @@ -224,7 +242,8 @@ class EmptyTensorListOp : public XlaOpKernel { if (got_shape) { xla::Shape list_shape; OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape( - element_shape, max_num_elements, &list_shape)); + element_shape, max_num_elements, + num_element_is_dynamic, &list_shape)); // Set up dynamic dimension sizes to create the zero tensor. auto list_dynamic_dims_or = GetTensorListDynamicDims( ctx, element_shape, list_shape, max_num_elements); @@ -243,7 +262,8 @@ class EmptyTensorListOp : public XlaOpKernel { // We are creating a nested TensorList or a non-nested TensorList with // unknown shape. Just create an uninitialized TensorList. 
xla::XlaOp result = - BuildUninitializedTensorList(ctx->builder(), max_num_elements); + BuildUninitializedTensorList(ctx->builder(), max_num_elements, + num_element_is_dynamic, ctx->Input(1)); ctx->SetTensorListOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index aa71e4d4364..156f9bfea40 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -189,28 +189,42 @@ Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, } xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b, - int64 leading_dimension) { + int64 leading_dimension, + bool leading_size_is_dynamic, + xla::XlaOp leading_dim_size) { auto zero = xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::PrimitiveType::S32)); - return xla::Broadcast(zero, std::vector{leading_dimension}); + auto broadcast = xla::Broadcast(zero, std::vector{leading_dimension}); + if (leading_size_is_dynamic) { + return xla::SetDimensionSize(broadcast, leading_dim_size, 0); + } else { + return broadcast; + } } -Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim) { +Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim, + bool* leading_dim_is_dynamic, + xla::XlaOp* leading_dim_dynamic_size) { bool is_initialized; TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); if (is_initialized) { auto buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0); + *leading_dim_is_dynamic = buffer_shape.is_dynamic_dimension(0); + auto buffer = xla::GetTupleElement(list, 0); *leading_dim = buffer_shape.dimensions(0); + *leading_dim_dynamic_size = xla::GetDimensionSize(buffer, 0); } else { + *leading_dim_is_dynamic = list_shape.is_dynamic_dimension(0); *leading_dim = list_shape.dimensions(0); + *leading_dim_dynamic_size = xla::GetDimensionSize(list, 0); } return Status::OK(); } Status GetTensorListShapeFromElementTensorListShape( const xla::Shape& element_tensor_list_shape, int64 leading_dim, - xla::Shape* tensor_list_shape) { + bool leading_dim_is_dynamic, xla::Shape* tensor_list_shape) { std::vector shapes; int tuple_size = xla::ShapeUtil::TupleElementCount(element_tensor_list_shape); for (int i = 0; i < tuple_size; i++) { @@ -220,6 +234,9 @@ Status GetTensorListShapeFromElementTensorListShape( dimensions.insert(dimensions.begin(), leading_dim); shapes.push_back( xla::ShapeUtil::MakeShape(shape.element_type(), dimensions)); + if (leading_dim_is_dynamic) { + shapes.back().set_dynamic_dimension(0, true); + } } shapes.push_back( xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector{})); @@ -229,6 +246,7 @@ Status GetTensorListShapeFromElementTensorListShape( Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, int64 leading_dim, + bool leading_dim_is_dynamic, xla::Shape* tensor_list_shape) { if (!element_shape.IsArray()) { return errors::InvalidArgument( @@ -236,12 +254,12 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, "shape. 
But element shape is ", element_shape.DebugString()); } - std::vector shapes; std::vector dimensions = xla::SpanToVector(element_shape.dimensions()); dimensions.insert(dimensions.begin(), leading_dim); shapes.push_back( xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions)); + shapes.back().set_dynamic_dimension(0, leading_dim_is_dynamic); shapes.push_back( xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector{})); *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes); @@ -279,7 +297,10 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element, bool element_is_tensor_list, xla::XlaOp* initialized_list) { int64 leading_dim; - TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(list, &leading_dim)); + xla::XlaOp leading_dim_dynamic_size; + bool leading_dim_is_dynamic; + TF_RETURN_IF_ERROR(GetLeadingDimForTensorList( + list, &leading_dim, &leading_dim_is_dynamic, &leading_dim_dynamic_size)); xla::XlaBuilder* b = list.builder(); xla::Shape list_shape; @@ -287,12 +308,11 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element, if (element_is_tensor_list) { TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape( - element_shape, leading_dim, &list_shape)); + element_shape, leading_dim, leading_dim_is_dynamic, &list_shape)); } else { TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape( - element_shape, leading_dim, &list_shape)); + element_shape, leading_dim, leading_dim_is_dynamic, &list_shape)); } - bool is_initialized; TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); if (is_initialized) { @@ -312,8 +332,7 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element, for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) { std::vector dynamic_dims; const xla::Shape& shape = list_shape.tuple_shapes(i); - // Leading dim is a static dimension. - dynamic_dims.push_back(xla::ConstantR0(b, leading_dim)); + dynamic_dims.push_back(leading_dim_dynamic_size); xla::XlaOp sub_element; if (element_is_tensor_list) { sub_element = xla::GetTupleElement(element, i); @@ -504,7 +523,9 @@ Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, xla::XlaOp list_part = xla::GetTupleElement(list, 0); xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape); - for (int64 i = 0; i < buffer_shape.dimensions_size(); ++i) { + // Propagate dynamic dimensions from buffer to the sliced buffer, except for + // leading dimension (which is always static 1). + for (int64 i = 1; i < buffer_shape.dimensions_size(); ++i) { if (buffer_shape.is_dynamic_dimension(i)) { auto buffer = xla::GetTupleElement(list, 0); auto gds = xla::GetDimensionSize(buffer, i); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h index ef3c8badf71..549ccd5aece 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h @@ -60,17 +60,22 @@ Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, // Returns an uninitialized TensorList. xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b, - int64 leading_dimension); + int64 leading_dimension, + bool leading_size_is_dynamic, + xla::XlaOp leading_dim_size); -// Returns leading dimension for the TensorList. -// Input can be initialized or uninitialized TensorList. -// Non-nested and nested TensorLists are both supported. 
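GetTensorListShapeFromElementShape above builds the list buffer shape by prepending the leading (num_elements) dimension to the element shape and, with this change, optionally marking that leading dimension as dynamic. A rough sketch of just that shape bookkeeping, using hypothetical plain-C++ types instead of xla::Shape:

    #include <cstdint>
    #include <vector>

    struct DimSpec {
      int64_t size;
      bool is_dynamic;
    };

    // Prepend the leading dimension and flag its dynamism; the element
    // dimensions keep whatever static sizes they already had.
    std::vector<DimSpec> MakeListBufferDims(
        const std::vector<int64_t>& element_dims, int64_t leading_dim,
        bool leading_dim_is_dynamic) {
      std::vector<DimSpec> dims;
      dims.push_back({leading_dim, leading_dim_is_dynamic});
      for (int64_t d : element_dims) dims.push_back({d, /*is_dynamic=*/false});
      return dims;
    }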
-Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim); +// Returns leading dimension for the TensorList as well as a dynamic op +// representing the dynamic size. Input can be initialized or uninitialized +// TensorList. Non-nested and nested TensorLists are both supported. +Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim, + bool* leading_dim_is_dynamic, + xla::XlaOp* leading_dim_dynamic_size); // Returns TensorList shape for the element shape. // Element shape must be a normal tensor shape. Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, int64 leading_dim, + bool leading_dim_is_dynamic, xla::Shape* tensor_list_shape); // Returns a TensorList filled by zeros with the given shape. diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index fe7a5898011..a94411f1b30 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -513,10 +513,26 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // Prepare dynamic dimensions for element shapes. std::vector> list_dynamic_dims; for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) { - // Set dynamic dimension size to 0 for initilization value. std::vector dynamic_dims; + const xla::Shape& shape = list_shape.tuple_shapes(i); - for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) { + + // We already have the dynamic size of leading dimension outside of + // the while loop without initializing the TensorList inside the while + // loop. + if (shape.is_dynamic_dimension(0)) { + xla::XlaOp leading_dim_size = xla::GetDimensionSize(input, 0); + dynamic_dims.push_back(leading_dim_size); + } else { + int32 dim_size = shape.dimensions(0); + dynamic_dims.push_back( + xla::ConstantR0(ctx->builder(), dim_size)); + } + + // Set dynamic dimension size to 0 for element value. Inside the while + // loop, TensorlistSetItem will properly set the element shape's + // dynamic diemnsion. + for (int64 dim = 1; dim < shape.dimensions_size(); ++dim) { int32 dim_size = shape.dimensions(dim); if (shape.is_dynamic_dimension(dim)) { dim_size = 0; diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc index e5913a8bbf3..eb1ab79d165 100644 --- a/tensorflow/compiler/tf2xla/lib/data_format.cc +++ b/tensorflow/compiler/tf2xla/lib/data_format.cc @@ -62,7 +62,7 @@ xla::StatusOr Expand(xla::XlaOp input, int64 dim) { std::vector expanded_shape = xla::SpanToVector(input_shape.dimensions()); expanded_shape[dim] /= 4; - expanded_shape.insert(expanded_shape.begin() + dim, 4); + expanded_shape.insert(expanded_shape.begin() + dim + 1, 4); // Move the newly created dimension to the end with a transpose. 
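The data_format.cc fix above changes where the new size-4 dimension lands: the expanded dimension is divided by 4 and the 4 is inserted immediately after it (begin() + dim + 1) rather than before it. The shape arithmetic as a standalone sketch (helper name is illustrative):

    #include <cstdint>
    #include <vector>

    // Split dimension `dim` of size 4*k into (k, 4), keeping the 4 right after
    // the dimension it came from. E.g. ExpandDimBy4({2, 8, 3}, 1) == {2, 2, 4, 3}.
    std::vector<int64_t> ExpandDimBy4(std::vector<int64_t> shape, int64_t dim) {
      shape[dim] /= 4;
      shape.insert(shape.begin() + dim + 1, 4);
      return shape;
    }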
std::vector permutation; diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index abaeb305104..92a83436346 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -90,18 +90,6 @@ Status ConvertOutputInfo(const tf2xla::Config& config, return ParseOutputArrayInfo(array_names, &specs->outputs); } -static void RegisterDialects() { - static bool init_once = []() { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - return true; - }(); - (void)init_once; -} - } // namespace Status ConvertGraphDefToXlaViaMlir( @@ -150,7 +138,6 @@ Status ConvertGraphDefToXlaViaMlir( } } - RegisterDialects(); mlir::MLIRContext context; TF_ASSIGN_OR_RETURN( mlir::OwningModuleRef module, diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index f4b9e9654d2..2f895b17219 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -291,6 +291,16 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto. precision_config: a serialized xla::PrecisionConfig proto. )doc"); +REGISTER_OP("XlaSetBound") + .Input("input: int32") + .Input("bound: int32") + .Output("output: int32") + .SetShapeFn(shape_inference::UnknownShape) + .Doc( + R"doc(Set a bound for the given input value as a hint to Xla compiler, + returns the same value. +)doc"); + REGISTER_OP("XlaDynamicSlice") .Input("input: T") .Input("start_indices: Tindices") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 846dafa2570..19104518b71 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -387,6 +387,14 @@ def reduce_window(operand, replica_id = gen_xla_ops.xla_replica_id +# Set a static bound for the given input value as a hint to Xla compiler, +# returns the same value. +# Usage: +# def f(t, p): +# p = xla.set_bound(p, 3) # Tells xla the constraint that p <= 3. +# return t[:p] # xla knows the bound of the slice is 3. +set_bound = gen_xla_ops.xla_set_bound + def reshape(x, new_sizes, dimensions=None, name=None): if dimensions is not None: diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 2db431c0413..860c3a40424 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -83,6 +83,8 @@ CreateResourceOpInfoMap() { add("ResourceScatterSub" , kReadWrite, kVariable); add("ResourceScatterUpdate" , kReadWrite, kVariable); add("ResourceStridedSliceAssign" , kReadWrite, kVariable); + add("RngReadAndSkip" , kReadWrite, kVariable); + add("RngSkip" , kReadWrite, kVariable); add("StatefulStandardNormalV2" , kReadWrite, kVariable); add("StatefulTruncatedNormal" , kReadWrite, kVariable); add("StatefulUniform" , kReadWrite, kVariable); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 242a2b04ab9..3cf9df64b0b 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -137,7 +137,6 @@ Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) { const auto& it = node.attr().find("allowed_devices"); if (it != node.attr().end()) { if (!it->second.list().s().empty()) { - // TODO(b/149512838): Support non-empty allowed devices. 
return errors::InvalidArgument( "VarHandleOp with non-empty allowed devices is not supported."); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 635b7170d82..b22dc05eaa1 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" +#include "tensorflow/compiler/mlir/utils/array_container_utils.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -36,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/device.h" @@ -732,7 +734,7 @@ Status XlaCompiler::CompileFunction( VLOG(1) << "Using MLIR bridge"; GraphDebugInfo debug_info; TF_RETURN_IF_ERROR(CompileGraphToXlaHlo( - std::move(*graph), {args.data(), args.size()}, + std::move(*graph), mlir::SpanToArrayRef(args), options_.device_type.type_string(), options.use_tuple_arg, *options_.flib_def, debug_info, options_.shape_representation_fn, result)); @@ -990,20 +992,6 @@ Status XlaCompiler::BuildArguments( tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } - for (int i = 0, end = input_to_args->size(); i < end; ++i) { - const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; - for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { - int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); - VLOG(1) << "Setting dynamic binding " << i << " -> " - << dynamic_size_param_index; - - TF_RETURN_IF_ERROR(builder->SetDynamicBinding( - /*dynamic_size_param_num=*/0, {dynamic_size_param_index}, - /*target_param_num=*/0, /*target_param_index=*/{i}, - dim_and_arg_num.first)); - } - } - for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { auto it = arg_shardings.find(i); xla::XlaScopedShardingAssignment assign_sharding( @@ -1035,16 +1023,17 @@ Status XlaCompiler::BuildArguments( absl::StrCat("arg", i)); } } + } - for (int i = 0, end = input_to_args->size(); i < end; ++i) { - const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; - for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { - int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); - TF_RETURN_IF_ERROR(builder->SetDynamicBinding( - /*dynamic_size_param_num=*/dynamic_size_param_index, {}, - /*target_param_num=*/i, /*target_param_index=*/{}, - dim_and_arg_num.first)); - } + for (int i = 0, end = input_to_args->size(); i < end; ++i) { + const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; + for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { + int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); + VLOG(1) << "Setting dynamic size " << i << " -> " + << dynamic_size_param_index; + arg_handles[i] = xla::SetDimensionSize( + arg_handles[i], arg_handles[dynamic_size_param_index], + dim_and_arg_num.first); } } @@ -1155,7 +1144,11 @@ Status ValidateGraph(const Graph* graph, 
return errors::InvalidArgument(absl::StrCat( "Detected unsupported operations when trying to compile graph ", name, " on ", device_type.type_string(), ": ", node->def().op(), " (", - s.error_message(), ")", FormatNodeForError(*node))); + s.error_message(), ")", FormatNodeForError(*node), + "One approach is to outside compile the unsupported ops to run on " + "CPUs by enabling soft placement " + "`tf.config.set_soft_device_placement(True)`." + " This has a potential performance penalty.")); } return Status::OK(); }; @@ -1370,8 +1363,15 @@ Status XlaCompiler::SetDeviceToHostMetadata( const string& key, absl::Span types, absl::Span shapes) { if (host_compute_sends_.find(key) != host_compute_sends_.end()) { - return errors::InvalidArgument( - "Duplicate calls to SetDeviceToHostMetadata with key ", key); + tf2xla::HostTransferMetadata& existing_transfer = host_compute_sends_[key]; + tf2xla::HostTransferMetadata new_transfer; + SetTransfer(key, types, shapes, &new_transfer); + if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) { + return Status::OK(); + } else { + return errors::InvalidArgument( + "Duplicate calls to SetDeviceToHostMetadata with key ", key); + } } tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key]; SetTransfer(key, types, shapes, &transfer); @@ -1396,9 +1396,16 @@ Status XlaCompiler::GetDeviceToHostShapes( Status XlaCompiler::SetHostToDeviceMetadata( const string& key, absl::Span types, absl::Span shapes) { - if (host_compute_recvs_.find(key) != host_compute_sends_.end()) { - return errors::InvalidArgument( - "Duplicate calls to SetHostToDeviceMetadata with key ", key); + if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) { + tf2xla::HostTransferMetadata& existing_transfer = host_compute_recvs_[key]; + tf2xla::HostTransferMetadata new_transfer; + SetTransfer(key, types, shapes, &new_transfer); + if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) { + return Status::OK(); + } else { + return errors::InvalidArgument( + "Duplicate calls to SetHostToDeviceMetadata with key ", key); + } } tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key]; SetTransfer(key, types, shapes, &transfer); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index b0d93cde846..762700eaea8 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -129,8 +129,6 @@ class XlaCompiler { // Resource updates are converted into input / output of xla. The two // buffers are aliased with other if this option is true. - // - // Currently only supports TPU. bool alias_resource_update = false; }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 5df508d60b3..f348552050b 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -1897,5 +1897,63 @@ TEST_F(XlaCompilerTest, AliasResourceUpdates) { EXPECT_EQ(alias.entries(0).parameter_number(), 0); } +// Tests that passing in an exact duplicate input to SetDeviceToHostMeatadata +// is not an error. 
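The SetDeviceToHostMetadata / SetHostToDeviceMetadata change above (exercised by the tests that follow) makes repeated registration with an identical payload a no-op, while a mismatched payload for the same key remains an error. A toy version of that policy over a plain map, with the Status return simplified to bool:

    #include <map>
    #include <string>

    // Returns true if the key is newly registered or re-registered with the
    // exact same value; false models the InvalidArgument case for a
    // conflicting duplicate.
    bool RegisterOnce(std::map<std::string, std::string>& table,
                      const std::string& key, const std::string& value) {
      auto it = table.find(key);
      if (it != table.end()) return it->second == value;
      table.emplace(key, value);
      return true;
    }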
+TEST_F(XlaCompilerTest, SetDeviceToHostMetadataExactDuplicate) { + XlaCompiler compiler(DefaultOptions()); + + const string& key = "comm_key"; + std::vector types{DT_INT32}; + std::vector shapes{TensorShape({2})}; + + TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes)); + TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes)); +} + +// Tests that passing in a mismatched duplicate input to +// SetDeviceToHostMeatadata is not an error. +TEST_F(XlaCompilerTest, SetDeviceToHostMetadataMismatchedDuplicate) { + XlaCompiler compiler(DefaultOptions()); + + const string& key = "comm_key"; + std::vector types{DT_INT32}; + std::vector shapes{TensorShape({2})}; + std::vector types2{DT_FLOAT}; + std::vector shapes2{TensorShape({1})}; + + TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes)); + Status status = compiler.SetDeviceToHostMetadata(key, types2, shapes2); + EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT); +} + +// Tests that passing in an exact duplicate input to SetHostToDeviceMeatadata +// is not an error. +TEST_F(XlaCompilerTest, SetHostToDeviceMetadataExactDuplicate) { + XlaCompiler compiler(DefaultOptions()); + + const string& key = "comm_key"; + std::vector types{DT_INT32}; + std::vector shapes{TensorShape({2})}; + + TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes)); + TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes)); +} + +// Tests that passing in a mismatched duplicate input to +// SetHostToDeviceMeatadata is not an error. +TEST_F(XlaCompilerTest, SetHostToDeviceMetadataMismatchedDuplicate) { + XlaCompiler compiler(DefaultOptions()); + + const string& key = "comm_key"; + std::vector types{DT_INT32}; + std::vector shapes{TensorShape({2})}; + std::vector types2{DT_FLOAT}; + std::vector shapes2{TensorShape({1})}; + + TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes)); + Status status = compiler.SetHostToDeviceMetadata(key, types2, shapes2); + EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index 34e108bb6bf..f0cc8d26709 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -101,6 +101,48 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { }); } +xla::StatusOr XlaExpression::ResolveDynamism( + xla::Client* client) const { + switch (kind()) { + case Kind::kConstant: { + // Constant values are considered static. + Tensor constant_false(DT_BOOL, constant_value().shape()); + auto flat = constant_false.flat(); + for (int64 i = 0; i < flat.size(); ++i) flat(i) = false; + return constant_false; + } + case Kind::kXlaOp: + break; + case Kind::kTensorList: + TF_FALLTHROUGH_INTENDED; + case Kind::kResource: + TF_FALLTHROUGH_INTENDED; + case Kind::kInvalid: + return errors::InvalidArgument( + "ResolveDynamism called on unsupported XlaExpression: ", + HumanString()); + } + + if (!client) + return errors::InvalidArgument("client is required to resolve constant"); + + TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph, + handle().builder()->BuildDynamicInferenceGraph(handle())); + + TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape()); + + // The XLA layout is specified minor to major, and TensorFlow uses a major to + // minor order. 
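The layout comment above refers to the usual convention mismatch: TensorFlow host tensors are laid out major-to-minor (row-major), while an XLA layout lists dimensions minor-to-major, so a default row-major tensor of rank n maps to the layout {n-1, ..., 1, 0}. The reversed-iota construction, shown standalone:

    #include <numeric>
    #include <vector>

    // Row-major (major-to-minor) order expressed as an XLA-style
    // minor-to-major list, e.g. RowMajorMinorToMajor(3) == {2, 1, 0}.
    std::vector<int> RowMajorMinorToMajor(int rank) {
      std::vector<int> layout(rank);
      std::iota(layout.rbegin(), layout.rend(), 0);
      return layout;
    }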
+ std::vector layout_indices(shape.dims()); + std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); + xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); + TF_ASSIGN_OR_RETURN(xla::Literal literal, + client->ComputeConstant(constant_graph, &layout)); + Tensor tensor(DT_BOOL); + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &tensor)); + return tensor; +} + xla::StatusOr> XlaExpression::ResolveConstant( xla::Client* client, bool dynamic_dimension_is_minus_one) const { switch (kind()) { diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index 3010964c5b7..3546368ff7b 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -99,6 +99,10 @@ class XlaExpression { xla::StatusOr> ResolveConstant( xla::Client* client, bool dynamic_dimension_is_minus_one = false) const; + // ResolveDynamism computes where a value inside this op is dynamic or can be + // inferred at compile time. + xla::StatusOr ResolveDynamism(xla::Client* client) const; + // Returns the shape of the tensor. // The shape of a resource is the shape of a resource handle (i.e., a scalar), // not the shape of the resource's value. diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 735a6c7291e..c2d1906e47a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -243,6 +243,74 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { return LiteralToFloat64Scalar(literal, out); } +static Status LiteralToPredVector(const xla::LiteralSlice& literal, + std::vector* out) { + if (literal.shape().rank() != 1) { + return errors::InvalidArgument("value is not 1D, rank: ", + literal.shape().rank()); + } + int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); + if (literal.shape().element_type() != xla::PRED) { + return errors::InvalidArgument("value is not PRED"); + } + for (int64 i = 0; i < size; ++i) { + out->push_back(literal.Get({i})); + } + return Status::OK(); +} + +Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) { + xla::Literal literal; + XlaExpression e = InputExpression(index); + auto* client = compiler() ? compiler()->client() : nullptr; + xla::StatusOr dynamism_or_status = e.ResolveDynamism(client); + if (!dynamism_or_status.ok()) { + Status status = dynamism_or_status.status(); + errors::AppendToMessage(&status, "while evaluating input dynamism", index, + " of ", context_->op_kernel().type_string()); + return status; + } + Tensor dynamism = dynamism_or_status.ValueOrDie(); + + Tensor temp(dynamism.dtype()); + TensorShape tensor_shape({}); + if (!temp.CopyFrom(dynamism, tensor_shape)) { + return errors::InvalidArgument( + context_->op_kernel().name(), " input ", index, " has shape ", + dynamism.shape().DebugString(), " which is not a R0 ", tensor_shape); + } + + TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp)); + *out = literal.Get({}); + return Status::OK(); +} + +Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector( + int index, std::vector* out) { + xla::Literal literal; + XlaExpression e = InputExpression(index); + auto* client = compiler() ? 
compiler()->client() : nullptr; + xla::StatusOr dynamism_or_status = e.ResolveDynamism(client); + if (!dynamism_or_status.ok()) { + Status status = dynamism_or_status.status(); + errors::AppendToMessage(&status, "while evaluating input dynamism", index, + " of ", context_->op_kernel().type_string()); + return status; + } + Tensor dynamism = dynamism_or_status.ValueOrDie(); + + Tensor temp(dynamism.dtype()); + TensorShape tensor_shape({InputShape(index).num_elements()}); + if (!temp.CopyFrom(dynamism, tensor_shape)) { + return errors::InvalidArgument( + context_->op_kernel().name(), " input ", index, " has shape ", + dynamism.shape().DebugString(), " which is not a R1 ", tensor_shape); + } + + TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp)); + return LiteralToPredVector(literal, out); +} + // Converts an int32 or int64 1D literal to an int64 vector. static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, std::vector* out) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 3cf51e6ec6f..1ed343ba20f 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -116,7 +116,10 @@ class XlaOpKernelContext { // returns a one-element list. Status InputList(absl::string_view name, std::vector* handles, std::vector* shapes); - + // Evaluates input and returns their dynamism vector in a vector of + // predicates. + Status ResolveInputDynamismIntoPredVector(int index, std::vector* out); + Status ResolveInputDynamismIntoPred(int index, bool* out); // Helper methods for constant inputs. // Evaluates input `index` and stores it in `*constant_literal`. If the diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index e37f4659185..9948fe6d1b9 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -365,6 +365,19 @@ std::vector XlaOpRegistry::DeviceKernels( return ops; } +/*static*/ const std::unordered_set* +XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto it = registry.ops_.find(op); + static auto empty_set = new std::unordered_set; + if (it == registry.ops_.end() || it->second.empty()) { + return empty_set; + } else { + return &it->second.front()->compile_time_constant_inputs; + } +} + /* static */ Status XlaOpRegistry::CompileTimeConstantInputs( const NodeDef& node_def, const OpKernel* op_kernel, const OpDef* op_def, std::vector* result) { @@ -385,21 +398,10 @@ std::vector XlaOpRegistry::DeviceKernels( compile_time_constant_inputs_from_attr.end())); compile_time_constant_inputs = &compile_time_constant_inputs_from_attr; } else { - const string& op = node_def.op(); - - XlaOpRegistry& registry = Instance(); - mutex_lock lock(registry.mutex_); - auto it = registry.ops_.find(op); - if (it == registry.ops_.end() || it->second.empty()) { + compile_time_constant_inputs = + CompileTimeConstantInputArgNames(node_def.op()); + if (compile_time_constant_inputs->empty()) { return Status::OK(); - } else { - // The test in IsCompatible ensures that if there are multiple matching - // registrations for this op name, they all have the same value of - // compile_time_constant_inputs, so only the first match is returned. - // - // TODO(sanjoy): This can probably be a std::vector. 
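CompileTimeConstantInputArgNames above always hands back a usable pointer: a registered op returns its set of constant-input argument names, and an unknown op returns a function-local static empty set rather than null or a dangling temporary. The lookup pattern in isolation (names here are illustrative):

    #include <string>
    #include <unordered_map>
    #include <unordered_set>

    using StringSet = std::unordered_set<std::string>;

    // Never returns null: misses resolve to a shared, immutable empty set
    // that lives for the duration of the program.
    const StringSet* LookupOrEmpty(
        const std::unordered_map<std::string, StringSet>& table,
        const std::string& key) {
      static const StringSet* const kEmpty = new StringSet;
      auto it = table.find(key);
      return it == table.end() ? kEmpty : &it->second;
    }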
- compile_time_constant_inputs = - &it->second.front()->compile_time_constant_inputs; } } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index af720fb4bb9..9533acb6a0c 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -198,6 +198,11 @@ class XlaOpRegistry { /*op_def=*/nullptr, result); } + // Return names of arguments for a given op which are supposed to be + // constants. + static const std::unordered_set* + CompileTimeConstantInputArgNames(const string& op); + // Returns true if `op` is a "metadata" op, one that only looks at the shapes // of its operands and not their values. static bool IsMetadataOp(const string& op); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 35fa6a617f0..598112e00df 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -969,6 +969,11 @@ tf_cc_test( ], ) +cc_library( + name = "union_find", + hdrs = ["union_find.h"], +) + # ----------------------------------------------------------------------------- # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 392cd9bd359..a85d551769c 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -289,13 +289,19 @@ class Array { } // Fills the array with random normal variables with the specified mean. - void FillRandom(const T& stddev, const double mean = 0.0, - const int seed = 12345) { + void FillRandom(const T& stddev, double mean = 0.0, int seed = 12345) { + FillRandomDouble(static_cast(stddev), mean, seed); + } + + void FillRandomDouble(double stddev, double mean = 0.0, int seed = 12345) { std::mt19937 g(seed); - std::normal_distribution distribution(mean, - static_cast(stddev)); + std::normal_distribution distribution(mean, stddev); for (int64 i = 0; i < num_elements(); ++i) { - values_[i] = static_cast(distribution(g)); + if (std::is_same()) { + values_[i] = static_cast(distribution(g) > 0.0); + } else { + values_[i] = static_cast(distribution(g)); + } } } diff --git a/tensorflow/compiler/xla/bit_cast.h b/tensorflow/compiler/xla/bit_cast.h index 90e9a5c25dd..feb548c9433 100644 --- a/tensorflow/compiler/xla/bit_cast.h +++ b/tensorflow/compiler/xla/bit_cast.h @@ -29,7 +29,7 @@ limitations under the License. 
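The Array::FillRandom change above adds a bool path: samples are drawn from a double-precision normal distribution and thresholded at zero instead of being cast directly (a direct cast to bool would make almost every entry true). A plain std:: sketch of that behaviour outside the xla::Array class:

    #include <cstdint>
    #include <random>
    #include <vector>

    // Draw n normal samples and keep only their sign: roughly half the
    // entries come out true when mean == 0.
    std::vector<bool> FillRandomBool(int64_t n, double stddev,
                                     double mean = 0.0, int seed = 12345) {
      std::mt19937 g(seed);
      std::normal_distribution<double> distribution(mean, stddev);
      std::vector<bool> values(static_cast<size_t>(n));
      for (int64_t i = 0; i < n; ++i) {
        values[static_cast<size_t>(i)] = distribution(g) > 0.0;
      }
      return values;
    }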
#include "absl/base/casts.h" #include "third_party/eigen3/Eigen/Core" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 06fd8ceeb2b..6cd77bf9f19 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -55,9 +55,13 @@ xla_test( cc_library( name = "comparators", srcs = ["comparators.cc"], - hdrs = ["comparators.h"], + hdrs = [ + "comparators.h", + "//tensorflow/compiler/xla:literal_util", + ], deps = [ ":constants", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -195,6 +199,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -300,6 +305,20 @@ xla_test( "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "//tensorflow/core/platform:tensor_float_32_utils", + ], +) + +cc_library( + name = "lu_decomposition", + srcs = ["lu_decomposition.cc"], + hdrs = ["lu_decomposition.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:xla_builder", ], ) @@ -340,6 +359,9 @@ cc_library( hdrs = ["sorting.h"], deps = [ ":comparators", + ":constants", + ":loops", + ":slicing", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -571,6 +593,7 @@ cc_library( ":loops", ":math", ":matrix", + ":qr", ":slicing", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 20d9930341f..744cdcea14c 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -137,7 +137,7 @@ XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder, arg_max = Select(eq, tie_id, arg_max); } Tuple(b, {max, arg_max}); - return b->Build().ConsumeValueOrDie(); + return b->BuildAndNoteError(); } XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min, diff --git a/tensorflow/compiler/xla/client/lib/comparators.cc b/tensorflow/compiler/xla/client/lib/comparators.cc index 74e89b767cf..c9d6cea740d 100644 --- a/tensorflow/compiler/xla/client/lib/comparators.cc +++ b/tensorflow/compiler/xla/client/lib/comparators.cc @@ -32,85 +32,13 @@ limitations under the License. 
namespace xla { namespace { -using XlaOpGenerator = XlaOp (*)(XlaOp, XlaOp, absl::Span); - -XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value, - int64 bit_width) { - PrimitiveType signed_type; - PrimitiveType unsigned_type; - XlaOp max_value; - switch (bit_width) { - case 16: - max_value = - ConstantR0(value.builder(), - static_cast(std::numeric_limits::max())); - signed_type = S16; - unsigned_type = U16; - break; - case 32: - max_value = - ConstantR0(value.builder(), - static_cast(std::numeric_limits::max())); - signed_type = S32; - unsigned_type = U32; - break; - case 64: - max_value = - ConstantR0(value.builder(), - static_cast(std::numeric_limits::max())); - signed_type = S64; - unsigned_type = U64; - break; - default: - return value.builder()->ReportError( - InvalidArgument("Invalid bit width %lld for Comparator floating " - "point parameter.", - bit_width)); - } - // Switch from a floating point value to a integer value in such a way that - // when using the integer value to compare, we get the same result for normal - // values, and -Nan is treated as the smallest value, and Nan is treated as - // the largest value. - // If f is a float, and - // x = bit_cast(f); - // y = x < 0 ? numeric_limits::max() - x : x; - // then y is ordered as an int32 such that finite values have the obvious - // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning - // and end of the ordering. - // Note that in order to avoid -x to overflow, we calculate - // numeric_limits::max() - x as unsigned, and then convert back to - // signed. - auto signed_value = BitcastConvertType(value, signed_type); - auto unsigned_value = BitcastConvertType(value, unsigned_type); - auto flipped_value = - BitcastConvertType(Sub(max_value, unsigned_value), signed_type); - auto is_negative = Lt(signed_value, Zero(value.builder(), signed_type)); - return Select(is_negative, flipped_value, signed_value); -} - -void ConvertFloatingPoint(const PrimitiveType& operand_type, XlaOp* lhs_param, - XlaOp* rhs_param) { - if (primitive_util::IsFloatingPointType(operand_type)) { - PrimitiveType compare_type = operand_type; - // Special-case handling for BF16. We currently do not support direct - // comparisons with BF16, so we convert to F32 and then use the F32 - // comparison logic. 
- if (compare_type == BF16) { - compare_type = F32; - *lhs_param = ConvertElementType(*lhs_param, F32); - *rhs_param = ConvertElementType(*rhs_param, F32); - } - int64 bit_width = primitive_util::BitWidth(compare_type); - *lhs_param = BitcastConvertFloatingPointToIntegral(*lhs_param, bit_width); - *rhs_param = BitcastConvertFloatingPointToIntegral(*rhs_param, bit_width); - } -} +using XlaCompareOp = XlaOp (*)(XlaOp, XlaOp, absl::Span); XlaComputation CreateScalarComparisonComputation( const string& name, const std::vector& operand_types, - XlaBuilder* builder, XlaOpGenerator generator) { + XlaBuilder* builder, XlaCompareOp generator) { CHECK_NE(operand_types.size(), 0); - std::vector> generators(operand_types.size()); + std::vector> generators(operand_types.size()); generators[0] = generator; return CreateScalarComparisonComputation(name, operand_types, generators, builder); @@ -119,7 +47,7 @@ XlaComputation CreateScalarComparisonComputation( XlaComputation CreateScalarComparisonComputation( const string& name, const std::vector& operand_types, - const std::vector>& generators, + const std::vector>& generators, XlaBuilder* builder) { // Create a default computation where we compare only the first two // parameters of type 'operand_types[0]'. @@ -146,7 +74,6 @@ XlaComputation CreateScalarComparisonComputation( absl::StrCat("p.", parameter_count, ".lhs")); auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape, absl::StrCat("p.", parameter_count, ".rhs")); - ConvertFloatingPoint(operand_type, &lhs_param, &rhs_param); lhs_params.emplace_back(lhs_param); rhs_params.emplace_back(rhs_param); if (generators[parameter_count].has_value()) { @@ -157,7 +84,12 @@ XlaComputation CreateScalarComparisonComputation( CHECK_NE(parameter_count, 0); - Shape shape = b->GetShape(lhs_params[0]).ValueOrDie(); + auto shape_or = b->GetShape(lhs_params[0]); + if (!shape_or.ok()) { + b->ReportError(shape_or.status()); + return {}; + } + Shape shape = shape_or.ValueOrDie(); shape.set_element_type(PRED); XlaOp param_equal = Broadcast(One(b.get(), shape.element_type()), AsInt64Slice(shape.dimensions())); @@ -169,7 +101,8 @@ XlaComputation CreateScalarComparisonComputation( generators[i].value()(lhs_params[i], rhs_params[i], {}), result); if (i != last_generator_index) { - param_equal = And(param_equal, Eq(lhs_params[i], rhs_params[i])); + param_equal = + And(param_equal, EqTotalOrder(lhs_params[i], rhs_params[i])); } } } @@ -181,14 +114,14 @@ XlaComputation CreateScalarComparisonComputation( XlaComputation CreateScalarLtComputation( const std::vector& operand_types, XlaBuilder* builder) { return CreateScalarComparisonComputation("compare-less-than", operand_types, - builder, Lt); + builder, LtTotalOrder); } // Creates a scalar greater-than computation and returns it. 
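The comparator rewrite above drops the hand-rolled BitcastConvertFloatingPointToIntegral path in favour of the *TotalOrder comparison ops, but the removed comment still describes the underlying trick: bit-cast the float to an integer and flip negative values so that ordinary signed comparison yields a total order with -NaN first and NaN last. A host-side sketch of that mapping for 32-bit floats:

    #include <cstdint>
    #include <cstring>
    #include <limits>

    // Monotone float -> int32 key: finite values keep their order, -0 sorts
    // before +0, and -NaN / NaN land at the two extremes. The subtraction is
    // done in unsigned arithmetic to avoid signed overflow.
    int32_t TotalOrderKey(float f) {
      int32_t s;
      uint32_t u;
      std::memcpy(&s, &f, sizeof(f));
      std::memcpy(&u, &f, sizeof(f));
      if (s < 0) {
        u = static_cast<uint32_t>(std::numeric_limits<int32_t>::max()) - u;
        std::memcpy(&s, &u, sizeof(u));
      }
      return s;
    }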
XlaComputation CreateScalarGtComputation( const std::vector& operand_types, XlaBuilder* builder) { - return CreateScalarComparisonComputation("compare-greater-than", - operand_types, builder, Gt); + return CreateScalarComparisonComputation( + "compare-greater-than", operand_types, builder, GtTotalOrder); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/comparators.h b/tensorflow/compiler/xla/client/lib/comparators.h index 25924d4a4f4..a82a84799aa 100644 --- a/tensorflow/compiler/xla/client/lib/comparators.h +++ b/tensorflow/compiler/xla/client/lib/comparators.h @@ -43,14 +43,13 @@ XlaComputation CreateScalarGtComputation( const std::vector& operand_types, XlaBuilder* builder); // Creates a scalar comparison computation and returns it. This function takes -// an std::vector> and compare the operands -// where the generator isn't nullopt with the specified comparator -// at that location. +// a vector of comparator functions to compare the operands where the function +// isn't nullopt with the specified comparator at that location. XlaComputation CreateScalarComparisonComputation( const string& name, const std::vector& operand_types, const std::vector< absl::optional)>>& - generators, + comparators, XlaBuilder* builder); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/logdet.cc b/tensorflow/compiler/xla/client/lib/logdet.cc index 8f37c393922..18cd0870f2a 100644 --- a/tensorflow/compiler/xla/client/lib/logdet.cc +++ b/tensorflow/compiler/xla/client/lib/logdet.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -33,13 +34,46 @@ limitations under the License. namespace xla { -// let G = root(A) be the Cholesky root of the matrix A -// log(det(A)) = 2*sum(log(vecdiag(G))) +// log(det(A)) = sum(log(vecdiag(QR(A).r))), since R is triangular and Q is +// orthonormal XlaOp LogDet(XlaOp a) { - XlaOp cholesky = Cholesky(a, /*bool lower=*/true); + return a.builder()->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, a.builder()->GetShape(a)); + // Compute the number of Householder transformations required on 'a' by + // determining the number of rows in 'a' that are already triangular. 
The + // determinant of Q is -1 ^ (number of Householder transformations) + auto rows = Iota(a.builder(), ShapeUtil::ChangeElementType(a_shape, S32), + a_shape.rank() - 2); + auto cols = Iota(a.builder(), ShapeUtil::ChangeElementType(a_shape, S32), + a_shape.rank() - 1); + auto in_lower_triangle = Lt(cols, rows); + auto is_zero = Eq(a, ScalarLike(a, 0)); + auto num_zeros_in_triangle_per_row = Einsum( + ConvertElementType(And(in_lower_triangle, is_zero), S32), "...a->..."); + TF_ASSIGN_OR_RETURN(auto row_shape, + a.builder()->GetShape(num_zeros_in_triangle_per_row)); + rows = Iota(a.builder(), row_shape, row_shape.rank() - 1); + auto num_triangle_rows = + Einsum(ConvertElementType(Eq(rows, num_zeros_in_triangle_per_row), S32), + "...a->..."); + auto num_rows = + ScalarLike(num_triangle_rows, a_shape.dimensions(a_shape.rank() - 2)); - return ScalarLike(a, 2) * - Einsum(Log(cholesky), "...aa->...", xla::PrecisionConfig::HIGHEST); + TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, true)); + // Get the sign and log of the determinant based on the values along the diagonal + // of R. + auto log_abs_det = Einsum(Log(Abs(qr.r)), "...aa->..."); + auto sign_diag = Reduce( + Sign(Einsum(qr.r, "...aa->...a")), + One(a.builder(), a_shape.element_type()), + CreateScalarMultiplyComputation(a_shape.element_type(), a.builder()), + {a_shape.rank() - 2}); + return sign_diag * log_abs_det * + Select(ConvertElementType(Rem(num_rows - num_triangle_rows, + ScalarLike(num_triangle_rows, 2)), + PRED), + ScalarLike(sign_diag, -1.0), ScalarLike(sign_diag, 1.0)); + }); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/logdet_test.cc b/tensorflow/compiler/xla/client/lib/logdet_test.cc index 54af41f77f6..319d819ed98 100644 --- a/tensorflow/compiler/xla/client/lib/logdet_test.cc +++ b/tensorflow/compiler/xla/client/lib/logdet_test.cc @@ -51,6 +51,26 @@ XLA_TEST_F(LogDetTest, Simple) { xla::ErrorSpec(1e-4)); } +XLA_TEST_F(LogDetTest, SimpleTriangle) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {4, 6, 8, 10}, + {4, -39, 62, 73}, + {0, 0, -146, 166}, + {4, 6, 8, 320}, + }); + + float expected = -15.9131355f; + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::LogDet(a); + + ComputeAndCompareR0(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4)); +} + XLA_TEST_F(LogDetTest, SimpleBatched) { xla::XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/client/lib/lu_decomposition.cc b/tensorflow/compiler/xla/client/lib/lu_decomposition.cc new file mode 100644 index 00000000000..2920b6f56b5 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/lu_decomposition.cc @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
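Since det(A) = det(Q) * det(R), log|det(A)| is the sum of log|r_ii|, det(R)'s sign is the product of sign(r_ii), and det(Q) is +/-1; the code above recovers det(Q)'s sign from the parity of Householder reflections by counting rows that are already triangular. A NumPy sanity check of that decomposition, using the matrix from the SimpleTriangle test above (illustrative only; it reads det(Q) off Q directly instead of counting reflections):

```python
import numpy as np

# sign(det) and log|det| from QR, cross-checked with np.linalg.slogdet.
def slogdet_via_qr(a):
    q, r = np.linalg.qr(a)
    diag = np.diagonal(r)
    log_abs_det = np.sum(np.log(np.abs(diag)))
    sign = np.prod(np.sign(diag)) * np.round(np.linalg.det(q))  # det(Q) is +/-1
    return sign, log_abs_det

a = np.array([[4., 6., 8., 10.],
              [4., -39., 62., 73.],
              [0., 0., -146., 166.],
              [4., 6., 8., 320.]])
sign, log_abs = slogdet_via_qr(a)
print(sign * log_abs)          # ~ -15.9131..., the value SimpleTriangle expects
print(np.linalg.slogdet(a))    # same (sign, log|det|) pair
```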
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/lu_decomposition.h" + +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +LuDecompositionResult LuDecomposition(XlaOp a) { + XlaBuilder* builder = a.builder(); + XlaOp result = builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int ndims = a_shape.rank(); + TF_RET_CHECK(ndims >= 2); + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + const int num_batch_dims = a_shape.dimensions().size() - 2; + const std::vector batch_dims( + a_shape.dimensions().begin(), + a_shape.dimensions().begin() + num_batch_dims); + + std::vector pivot_dims = batch_dims; + pivot_dims.push_back(std::min(m, n)); + std::vector perm_dims = batch_dims; + perm_dims.push_back(m); + Shape lu_shape = ShapeUtil::MakeTupleShape( + {a_shape, ShapeUtil::MakeShape(S32, pivot_dims), + ShapeUtil::MakeShape(S32, perm_dims)}); + // The TPU compiler has a rewrite pass that lowers an LuDecomposition + // CustomCall. + // TODO(phawkins): upgrade LU decomposition to a first-class HLO operator + // and implement it on other backends. + return CustomCall(a.builder(), "LuDecomposition", {a}, lu_shape); + }); + return LuDecompositionResult{GetTupleElement(result, 0), + GetTupleElement(result, 1), + GetTupleElement(result, 2)}; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/lu_decomposition.h b/tensorflow/compiler/xla/client/lib/lu_decomposition.h new file mode 100644 index 00000000000..3f5703510a3 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/lu_decomposition.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Computes the LU decomposition with partial pivoting of a batch of matrices. +// +// Given a (batched) matrix a with shape [..., m, n], computes the matrix +// decomposition A = P @ L @ U where P is a permutation matrix, L is a +// lower-triangular matrix with unit diagonal entries, and U is an +// upper-triangular matrix. +// +// L and U are returned as a single matrix [..., m, n] containing both L and U +// packed in the same array. The unit diagonal of L is not represented +// explicitly. 
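The result tuple built above carries `pivots` (shape batch + [min(m, n)]) and `permutation` (shape batch + [m]); the header that follows explains why both are returned. A hedged NumPy sketch of how the two relate, assuming the usual 0-based LAPACK xGETRF convention that pivots[i] names the row swapped with row i (the numbers below are made up for illustration):

```python
import numpy as np

# Convert LAPACK-style pivots into an explicit row permutation: after applying
# the swaps in order, row i of P @ A is row perm[i] of A.
def pivots_to_permutation(pivots, m):
    perm = np.arange(m)
    for i, p in enumerate(pivots):
        perm[i], perm[p] = perm[p], perm[i]
    return perm

pivots = np.array([2, 2, 3, 3])                  # hypothetical swap sequence, m = 4
print(pivots_to_permutation(pivots, 4))          # [2 0 3 1]
# `pivots` is handy for determinants: det(P) = (-1) ** (number of real swaps).
print((-1.0) ** np.sum(pivots != np.arange(4)))  # -1.0
```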
+// +// The permutation matrix P is returned in two forms, both as `pivots`, which is +// an s32[..., min(m, n)] array that describes a sequence of row-swaps in the +// style of LAPACK's xGETRF API, and `permutation`, which is an s32[..., m] array +// which gives the permutation to apply to the rows. We return both +// representations because they are each useful for different purposes; `pivots` +// is useful for computing the sign of a determinant, whereas `permutation` can +// be used via a Gather operation to permute the rows of a matrix. +// +// This method is only implemented on TPU at the moment. +// TODO(b/168208200): the implementation only supports F32 arrays. Handle the +// complex case. +struct LuDecompositionResult { + // The LU decomposition, with both L and U packed into an array with shape + // [..., m, n]. + XlaOp lu; + // An array of shape s32[..., min(m, n)] containing the pivot rows. + XlaOp pivots; + // An array of shape s32[..., m], containing another representation of the + // pivots as a permutation. + XlaOp permutation; +}; + +LuDecompositionResult LuDecomposition(XlaOp a); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 6fdaab58686..cd9f88a74ce 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -1111,11 +1111,28 @@ XlaOp RoundToEven(XlaOp x) { // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1 // pi if x == -1 +// For complex: +// acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x)))) XlaOp Acos(XlaOp x) { - return Select(Ne(x, FullLike(x, -1)), - ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), - ScalarLike(x, 1.0) + x), - FullLike(x, M_PI)); + XlaBuilder* b = x.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x)); + + if (primitive_util::IsComplexType(shape.element_type())) { + auto one = ScalarLike(x, 1); + auto imag_one = Complex( + Zero(b, primitive_util::ComplexComponentType(shape.element_type())), + One(b, primitive_util::ComplexComponentType(shape.element_type()))); + + auto result = + Neg(imag_one * Log(x + imag_one * Sqrt((one + x) * (one - x)))); + return result; + } + return Select(Ne(x, FullLike(x, -1)), + ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), + ScalarLike(x, 1.0) + x), + FullLike(x, M_PI)); + }); } // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index cb79b2ef7db..ae4d839d8fa 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -660,5 +660,19 @@ XLA_TEST_F(MathTest, BesselI1eDouble) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, AcosComplexValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1>( + &builder, {{0, 0}, {0, 1}, {1, 1}, {0.8, 0.2}}); + + Acos(x); + std::vector> expected = { + {1.5707963267948966, 0}, + {1.5707963267948966, -0.881373587019543}, + {0.9045568943023814, -1.0612750619050357}, + {0.7011246914497526, -0.30527648462436596}}; + ComputeAndCompareR1>(&builder, expected, {}, error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index b7721f2bbc5..dbb73602801 100644 ---
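The complex branch added to Acos above uses the principal-value identity acos(x) = -i * log(x + i * sqrt((1 + x) * (1 - x))). It can be spot-checked in NumPy against the points used in the new AcosComplexValues test (an illustration, not the XLA lowering):

```python
import numpy as np

x = np.array([0 + 0j, 0 + 1j, 1 + 1j, 0.8 + 0.2j])
via_log = -1j * np.log(x + 1j * np.sqrt((1 + x) * (1 - x)))
print(via_log)       # ~[1.5708+0j, 1.5708-0.8814j, 0.9046-1.0613j, 0.7011-0.3053j]
print(np.arccos(x))  # matches to floating-point precision
```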
a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" @@ -235,85 +236,93 @@ XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } namespace { -std::vector EinsumDiagonalLabels(absl::Span config) { +absl::optional, 3>> EinsumDiagonalLabels( + absl::Span config) { std::vector unique_labels; + std::vector reduce_dims; + std::vector broadcast_dims; for (auto label = config.begin(); label != config.end(); ++label) { auto first_label = absl::c_find(config, *label); + auto dim = label - config.begin(); if (first_label == label) { unique_labels.push_back(*label); + broadcast_dims.push_back(dim); + } else { + reduce_dims.push_back(dim); } } if (unique_labels.size() == config.size()) { - unique_labels.clear(); + return absl::nullopt; } - return unique_labels; + return {{unique_labels, reduce_dims, broadcast_dims}}; } -} // namespace -xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span config) { +// Masks a tensor such that only the diagonal of repeated indices are non-zero. +// The result of this can be used to create a diagonal matrix with an identity +// reduction. +xla::XlaOp EinsumDiagonalMask(XlaOp x, absl::Span config) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { - if (EinsumDiagonalLabels(config).empty()) { - return x; - } TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); Shape iota_shape = x_shape; iota_shape.set_element_type(S32); XlaOp mask = ConstantR0(builder, true); - absl::InlinedVector reduce_dims; for (auto label = config.begin(); label != config.end(); ++label) { const int64 dim = label - config.begin(); auto first_label = absl::c_find(config, *label); - if (first_label == label) { - continue; + if (first_label != label) { + const int64 first_dim = first_label - config.begin(); + mask = And(mask, Eq(Iota(builder, iota_shape, first_dim), + Iota(builder, iota_shape, dim))); } - reduce_dims.push_back(dim); - const int64 first_dim = first_label - config.begin(); - mask = And(mask, Eq(Iota(builder, iota_shape, first_dim), - Iota(builder, iota_shape, dim))); } - auto zero = ScalarLike(x, 0); - return Reduce(Select(mask, x, zero), zero, - CreateScalarIdentityWithZeroComputation( - x_shape.element_type(), builder), - reduce_dims); + return Select(mask, x, ZerosLike(x)); }); } -Status ValidateEinsumNumericDimensions(absl::Span x_config, - absl::Span y_config, - absl::Span output_config) { - for (auto dim : output_config) { - if (absl::c_linear_search(x_config, dim) || - absl::c_linear_search(y_config, dim)) { - if (absl::c_count(output_config, dim) > 1) { - return InvalidArgument("Einsum has repeated output dimension."); - } - continue; +xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span config) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + auto labels = EinsumDiagonalLabels(config); + if (!labels) { + return x; } - return InvalidArgument( - "Einsum has output dimension without corresponding input dimension."); - } - for (auto dim : x_config) { - if (absl::c_linear_search(y_config, dim) || - absl::c_linear_search(output_config, dim)) { - if 
(absl::c_count(x_config, dim) > 1) { - return InvalidArgument("Einsum has repeated lhs dimension."); - } - } - } - for (auto dim : y_config) { - if (absl::c_linear_search(x_config, dim) || - absl::c_linear_search(output_config, dim)) { - if (absl::c_count(y_config, dim) > 1) { - return InvalidArgument("Einsum has repeated rhs dimension."); - } - } - } - return Status::OK(); + auto zero = ScalarLike(x, 0); + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + return Reduce(EinsumDiagonalMask(x, config), zero, + CreateScalarIdentityWithZeroComputation( + x_shape.element_type(), builder), + labels->at(1)); + }); } +xla::XlaOp EinsumInverseDiagonal(XlaOp x, absl::Span config) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + auto labels = EinsumDiagonalLabels(config); + if (!labels) { + return x; + } + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + std::vector broadcast_sizes; + int64 x_dim = 0; + for (auto label = config.begin(); label != config.end(); ++label) { + auto first_label = absl::c_find(config, *label); + if (first_label == label) { + broadcast_sizes.push_back(x_shape.dimensions(x_dim)); + ++x_dim; + } else { + broadcast_sizes.push_back( + broadcast_sizes[first_label - config.begin()]); + } + } + x = BroadcastInDim(x, broadcast_sizes, labels->at(2)); + return EinsumDiagonalMask(x, config); + }); +} +} // namespace + namespace { // Helper method to remove dimensions from a shape and dot dimension numbers // used to implement implicit broadcasting. @@ -347,21 +356,23 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { auto x_diagonal_labels = EinsumDiagonalLabels(x_config); + if (x_diagonal_labels) { + return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels->at(0), y, + y_config, output_config, precision); + } auto y_diagonal_labels = EinsumDiagonalLabels(y_config); - if (!x_diagonal_labels.empty() && !y_diagonal_labels.empty()) { - return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels, - EinsumDiagonal(y, y_config), y_diagonal_labels, - output_config, precision); - } else if (!x_diagonal_labels.empty()) { - return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels, y, y_config, - output_config, precision); - } else if (!y_diagonal_labels.empty()) { - return Einsum(x, x_config, EinsumDiagonal(y, y_config), y_diagonal_labels, - output_config, precision); + if (y_diagonal_labels) { + return Einsum(x, x_config, EinsumDiagonal(y, y_config), + y_diagonal_labels->at(0), output_config, precision); + } + auto output_diagonal_labels = EinsumDiagonalLabels(output_config); + if (output_diagonal_labels) { + return EinsumInverseDiagonal( + Einsum(x, x_config, y, y_config, output_diagonal_labels->at(0), + precision), + output_config); } - TF_RETURN_IF_ERROR( - ValidateEinsumNumericDimensions(x_config, y_config, output_config)); TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); const int64 x_rank = x_config.size(); @@ -372,41 +383,37 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, absl::flat_hash_set output_map; for (auto d : x_config) { - if (!x_map.insert(d).second) { - return InvalidArgument("XLA Einsum does not support rhs tracing"); - } + x_map.insert(d); } for (auto d : y_config) { - if (!y_map.insert(d).second) { - return InvalidArgument("XLA Einsum does not support lhs tracing"); - } + y_map.insert(d); } for (auto 
d : output_config) { - if (!output_map.insert(d).second) { - return InvalidArgument("XLA Einsum does not support output tracing"); - } + output_map.insert(d); } DotDimensionNumbers dnums; - std::vector lhs_outer_dims; auto is_batch_dim = [&](int64 d) { return x_map.contains(d) && y_map.contains(d) && output_map.contains(d); }; auto is_contracting = [&](int64 d) { return x_map.contains(d) && y_map.contains(d); }; + auto rhs_dimension_number = [&](int64 d) { return absl::c_find(y_config, d) - y_config.begin(); }; absl::InlinedVector rhs_outer_dims; + absl::InlinedVector lhs_outer_dims; absl::InlinedVector rhs_delete_dims; absl::InlinedVector lhs_delete_dims; for (int64 i = 0; i < x_rank; ++i) { auto dim_name = x_config[i]; const int64 rhs_dim = rhs_dimension_number(dim_name); + if (is_batch_dim(dim_name)) { if (x_shape.dimensions(i) == y_shape.dimensions(rhs_dim)) { dnums.add_lhs_batch_dimensions(i); @@ -442,63 +449,90 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, } absl::c_sort(rhs_outer_dims); - absl::InlinedVector output_transpose_dims; - absl::InlinedVector output_reduce_dims; - auto output_dimension_number = [&](int64 d) { + + auto output_dimension_number = [&](int64 d) -> absl::optional { auto pos = absl::c_find(output_config, d); if (pos == output_config.end()) { - const int64 dim = - output_transpose_dims.size() + output_reduce_dims.size(); - output_reduce_dims.push_back(dim); - } else { - output_transpose_dims.push_back(pos - output_config.begin()); + return absl::nullopt; } + return pos - output_config.begin(); }; for (auto d : dnums.lhs_batch_dimensions()) { - output_dimension_number(x_config[d]); + output_transpose_dims.push_back(*output_dimension_number(x_config[d])); } for (auto d : lhs_outer_dims) { - output_dimension_number(x_config[d]); + if (auto output_dim = output_dimension_number(x_config[d])) { + output_transpose_dims.push_back(*output_dim); + continue; + } + lhs_delete_dims.push_back(d); } for (auto d : rhs_outer_dims) { - output_dimension_number(y_config[d]); + if (auto output_dim = output_dimension_number(y_config[d])) { + output_transpose_dims.push_back(*output_dim); + continue; + } + rhs_delete_dims.push_back(d); } + const int64 transpose_rank = output_transpose_dims.size(); std::vector transpose_dims(output_rank); - for (int64 i = 0; i < output_rank; ++i) { + for (int64 i = 0; i < transpose_rank; ++i) { transpose_dims[output_transpose_dims[i]] = i; } // Remove ones that where broadcasted from the x and the y shape and adjust // the dimension numbers that are more minor than those dimensions. 
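The reworked Einsum path above handles repeated labels by zeroing everything off the generalized diagonal (EinsumDiagonalMask), reducing the duplicate dimensions (EinsumDiagonal), and, for repeated output labels, broadcasting back and re-masking (EinsumInverseDiagonal). A NumPy sketch of the idea for a single repeated label (illustrative, not the actual lowering):

```python
import numpy as np

x = np.arange(16.0).reshape(4, 4)
mask = np.equal.outer(np.arange(4), np.arange(4))   # Iota(dim 0) == Iota(dim 1)

# 'aa->a': mask off-diagonal entries to zero, then reduce the duplicate dim.
diag = np.where(mask, x, 0.0).sum(axis=1)
print(np.allclose(diag, np.einsum('aa->a', x)))     # True

# Inverse direction ('a->aa'): broadcast along the repeated dim, re-apply mask.
back = np.where(mask, np.broadcast_to(diag[:, None], (4, 4)), 0.0)
print(np.allclose(back, np.diag(diag)))             # True
```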
+ absl::c_sort(lhs_delete_dims); DeleteDimsFromContainer(lhs_delete_dims, &x_shape, dnums.mutable_lhs_batch_dimensions(), dnums.mutable_lhs_contracting_dimensions()); + + absl::c_sort(rhs_delete_dims); DeleteDimsFromContainer(rhs_delete_dims, &y_shape, dnums.mutable_rhs_batch_dimensions(), dnums.mutable_rhs_contracting_dimensions()); if (!lhs_delete_dims.empty()) { - x = Reshape(x, x_shape.dimensions()); + x = Reduce(x, ScalarLike(x, 0), + CreateScalarAddComputation(x_shape.element_type(), builder), + lhs_delete_dims); } if (!rhs_delete_dims.empty()) { - y = Reshape(y, y_shape.dimensions()); + y = Reduce(y, ScalarLike(y, 0), + CreateScalarAddComputation(y_shape.element_type(), builder), + rhs_delete_dims); } PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); auto dot = DotGeneral(x, y, dnums, &precision_proto); - if (!output_reduce_dims.empty()) { - dot = Reduce(dot, ScalarLike(dot, 0), - CreateScalarAddComputation(x_shape.element_type(), builder), - output_reduce_dims); + dot = Transpose(dot, transpose_dims); + if (transpose_rank == output_rank) { + return dot; } - return Transpose(dot, transpose_dims); + + auto is_output_only = [&](int64 d) { + return output_map.contains(d) && !x_map.contains(d) && !y_map.contains(d); + }; + + int64 dot_dim = 0; + std::vector new_dims; + new_dims.reserve(output_rank); + TF_ASSIGN_OR_RETURN(Shape dot_shape, builder->GetShape(dot)); + for (auto d : output_config) { + if (is_output_only(d)) { + new_dims.push_back(1); + } else { + new_dims.push_back(dot_shape.dimensions(dot_dim)); + } + } + return Reshape(dot, new_dims); }); } diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h index 46f70ed27b9..1a9f72dedf2 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -112,14 +112,6 @@ StatusOr, 3>> ParseEinsumString( // Returns an empty string if the einsum string already has an ->. std::string NormalizeEinsumString(absl::string_view einsum_config); -// Determine if each dimension label is in at least two inputs. -// -// NOTE: This function is meant for testing, there is no need to call it -// directly. -Status ValidateEinsumNumericDimensions(absl::Span x_config, - absl::Span y_config, - absl::Span output_config); - // Supports two operand einsum notation like "ab,cb->ac". xla::XlaOp Einsum( xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config, @@ -128,9 +120,6 @@ xla::XlaOp Einsum( xla::XlaOp x, absl::string_view einsum_config, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); -// Handles repeated indices within an operand by taking the tensor diagonal of -// the input. -xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span config); // Same as above but supporting numeric labels on dimensions. 
So "ab,cb->ac" // becomes: diff --git a/tensorflow/compiler/xla/client/lib/matrix_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc index ebbf39ec096..628447c289e 100644 --- a/tensorflow/compiler/xla/client/lib/matrix_test.cc +++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc @@ -233,12 +233,23 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) { }; std::vector> good_test_cases = { - {"ab", "bc", "ac"}, {"Bab", "Bbc", "Bac"}, - {"ab", "cd", "dcba"}, {"abc", "abd", "cbd"}, - {"...ab", "...bc", "...ac"}, {"a...bc", "...abd", "cbd..."}, - {"...ab", "...bc", "ac"}, {"...b", "...bc", "...c"}, - {"...abz", "...bc", "...ac"}, {"...ab", "...bcz", "...ac"}, - {"abz", "bc", "ac"}, {"ab", "bcz", "ac"}, + {"ab", "bc", "ac"}, + {"Bab", "Bbc", "Bac"}, + {"ab", "cd", "dcba"}, + {"abc", "abd", "cbd"}, + {"...ab", "...bc", "...ac"}, + {"a...bc", "...abd", "cbd..."}, + {"...ab", "...bc", "ac"}, + {"...b", "...bc", "...c"}, + {"...abz", "...bc", "...ac"}, + {"...ab", "...bcz", "...ac"}, + {"abz", "bc", "ac"}, + {"ab", "bcz", "ac"}, + + {"a", "b", "c"}, + {"...a", "...b", "...c"}, + {"abb", "bcc", "ac"}, + {"ab", "bc", "ad"}, }; for (auto test_case : good_test_cases) { auto parse_result_or_status = @@ -249,9 +260,6 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) { for (int i = 0; i < 3; ++i) { EXPECT_EQ(parse_result[i], to_vec(test_case[i])); } - EXPECT_TRUE(ValidateEinsumNumericDimensions( - parse_result[0], parse_result[1], parse_result[2]) - .ok()); } std::vector einsum_strings_that_fail_parsing = { @@ -261,24 +269,6 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) { auto parse_result_or_status = ParseEinsumString(test_case, 3, 3); EXPECT_FALSE(parse_result_or_status.status().ok()); } - std::vector> einsum_strings_that_fail_numeric_validation = - { - {"a", "b", "c"}, - {"...a", "...b", "...c"}, - {"abb", "bcc", "ac"}, - {"ab", "bc", "ad"}, - }; - - for (auto test_case : einsum_strings_that_fail_numeric_validation) { - auto parse_result_or_status = - ParseEinsumString(to_string(test_case[0], test_case[1], test_case[2]), - test_case[0].size(), test_case[1].size()); - EXPECT_TRUE(parse_result_or_status.status().ok()); - auto parse_result = parse_result_or_status.ValueOrDie(); - EXPECT_FALSE(ValidateEinsumNumericDimensions( - parse_result[0], parse_result[1], parse_result[2]) - .ok()); - } } XLA_TEST_F(MatrixTest, NormalizeEinsumString) { diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index 044a742eddd..60086773d18 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -426,32 +426,36 @@ RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state, XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval, XlaOp maxval) { XlaBuilder* builder = bits.builder(); - PrimitiveType value_type = - builder->GetShape(minval).ConsumeValueOrDie().element_type(); - PrimitiveType bit_type = - builder->GetShape(bits).ConsumeValueOrDie().element_type(); - CHECK((value_type == F32 && bit_type == U32) || - (value_type == F64 && bit_type == U64)); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* minval_shape, + builder->GetShapePtr(minval)); + TF_ASSIGN_OR_RETURN(const Shape* bits_shape, builder->GetShapePtr(bits)); + PrimitiveType value_type = minval_shape->element_type(); + PrimitiveType bit_type = bits_shape->element_type(); + CHECK((value_type == F32 && bit_type == U32) || + (value_type == F64 && bit_type == U64)); - // Form random mantissa bits for float/double, 
with a leading 1 bit. - int num_float_bits = primitive_util::BitWidth(value_type); - // Subtract one as SignificandWidth includes the leading 1 bit. - int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1; + // Form random mantissa bits for float/double, with a leading 1 bit. + int num_float_bits = primitive_util::BitWidth(value_type); + // Subtract one as SignificandWidth includes the leading 1 bit. + int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1; - // Ignore the exponent bits and convert the mantissa bits to the floating - // point type. - bits = ShiftRightLogical( - bits, ScalarLike(bits, num_float_bits - num_mantissa_bits)); + // Ignore the exponent bits and convert the mantissa bits to the floating + // point type. + bits = ShiftRightLogical( + bits, ScalarLike(bits, num_float_bits - num_mantissa_bits)); - // We have an integer-valued floating point number in the range - // [0, 2**{num_mantissa_bits}). - XlaOp values = ConvertElementType(bits, value_type); + // We have an integer-valued floating point number in the range + // [0, 2**{num_mantissa_bits}). + XlaOp values = ConvertElementType(bits, value_type); - // Divide by 2**{-num_mantissa_bits} to get a number in the range [0.0, 1.0). - values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits)); + // Divide by 2**{-num_mantissa_bits} to get a number in the range + // [0.0, 1.0). + values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits)); - // Multiply and add to shift to the range [minval, maxval). - return values * (maxval - minval) + minval; + // Multiply and add to shift to the range [minval, maxval). + return values * (maxval - minval) + minval; + }); } XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, @@ -483,6 +487,10 @@ std::pair BoxMullerTransform(XlaOp x0, XlaOp x1) { } // namespace +XlaOp PhiloxIncreaseCounter(XlaOp counter, XlaOp delta) { + return Uint128ToOp(Uint128AddUint64(Uint128FromOp(counter), delta)); +} + RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape) { PrimitiveType type = shape.element_type(); diff --git a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h index 107fd884de3..20ad223403d 100644 --- a/tensorflow/compiler/xla/client/lib/prng.h +++ b/tensorflow/compiler/xla/client/lib/prng.h @@ -89,6 +89,9 @@ RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state, xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, absl::Span scalars); +// Increases Philox counter (an uint128) by a delta (an uint64). +xla::XlaOp PhiloxIncreaseCounter(xla::XlaOp counter, xla::XlaOp delta); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ diff --git a/tensorflow/compiler/xla/client/lib/qr.cc b/tensorflow/compiler/xla/client/lib/qr.cc index b2eecbac309..09fa465a865 100644 --- a/tensorflow/compiler/xla/client/lib/qr.cc +++ b/tensorflow/compiler/xla/client/lib/qr.cc @@ -127,29 +127,24 @@ Status House(XlaOp x, XlaOp k, absl::Span batch_dims, // def qr(a): // m = a.shape[0] // n = a.shape[1] -// vs = np.zeros([m, n]) // taus = np.zeros([n]) // for j in xrange(min(m, n)): // v, tau, beta = house(a[:, j], j) -// # Unusually, we apply the Householder transformation to the entirety of -// # a, wasting FLOPs to maintain the static shape invariant that XLA -// # requires. For columns that precede j this has no effect. 
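The refactored ConvertRandomBitsToUniformFloatingPoint above keeps the same construction: shift the random bits down to the mantissa width, convert to the float type (giving an integer in [0, 2^mantissa_bits)), scale by 2^-mantissa_bits into [0, 1), then map affinely onto [minval, maxval). A NumPy sketch for the U32 -> F32 case (illustration only, not the XLA code):

```python
import numpy as np

def bits_to_uniform_f32(bits, minval, maxval):
    num_float_bits = 32
    num_mantissa_bits = 24 - 1                    # SignificandWidth - 1 for F32
    bits = bits >> np.uint32(num_float_bits - num_mantissa_bits)
    values = bits.astype(np.float32)              # integer-valued, in [0, 2**23)
    values = values * np.float32(np.ldexp(1.0, -num_mantissa_bits))  # -> [0, 1)
    return values * (maxval - minval) + minval

random_bits = np.frombuffer(np.random.bytes(20), dtype=np.uint32)  # 5 x u32
print(bits_to_uniform_f32(random_bits, np.float32(2.0), np.float32(3.0)))
```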
-// a[:, :] -= tau * np.dot(v[:, np.newaxis], -// np.dot(v[np.newaxis, :], a[:, :])) +// a[:, j+1:] -= tau * np.dot(v[:, np.newaxis], +// np.dot(v[np.newaxis, :], a[:, j+1:])) // # Form column j explicitly rather than relying on the precision of the // # Householder update. // a[j, j] = beta -// a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype) -// vs[:, j] = v +// a[j+1:, j] = v[j+1:] // taus[j] = tau -// return (q, vs, taus) +// return (a, taus) struct QRBlockResult { - // The factored R value - XlaOp r; + // The upper-triangular matrix R, packed together with the lower-triangular + // elementary Householder reflectors `vs` below the diagonal. + XlaOp a; // Representation of the Householder matrices I - beta v v.T XlaOp taus; // Shape: [..., n] - XlaOp vs; // Shape: [..., m, n] }; StatusOr QRBlock(XlaOp a, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); @@ -176,57 +171,52 @@ StatusOr QRBlock(XlaOp a, PrecisionConfig::Precision precision) { auto qr_body_fn = [&](XlaOp j, absl::Span values, XlaBuilder* builder) -> StatusOr> { auto a = values[0]; - auto vs = values[1]; - auto taus = values[2]; + auto taus = values[1]; - // v, beta = house(a[:, j], j) + // v, tau, beta = house(a[:, j], j) auto x = DynamicSliceInMinorDims(a, {j}, {1}); XlaOp v, tau, beta; TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j, batch_dims, m, &v, &tau, &beta)); + const int64 minor_dim = batch_dims.size(); + auto iota_mn = Iota( + builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {m, n})), + minor_dim + 1); + std::vector shape = batch_dims; shape.push_back(1); shape.push_back(m); auto v_broadcast = Reshape(v, shape); - // a[:, :] -= tau * np.dot(v[:, np.newaxis], - // np.dot(v[np.newaxis, :], a[:, :])) - auto vva = BatchDot(v_broadcast, a, precision); + // a[:, j+1:] -= tau * (v[:, np.newaxis] @ (v[np.newaxis, :] @ a[:, j+1:])) + // We use masking rather than a loop-variant shape to handle the j+1: + // indexing. + auto vva = BatchDot(v_broadcast, Select(Lt(j, iota_mn), a, ZerosLike(a)), + precision); vva = BatchDot(v_broadcast, true, vva, false, precision); a = a - Mul(tau, vva, /*broadcast_dimensions=*/batch_dim_indices); - // It is more precise to populate column 'k' explicitly, rather than - // computing it implicitly by applying the Householder transformation. 
- // a[k,k] = beta - // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype) + // a[j, j] = beta + // a[j+1:,j] = v[j+1:] auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1}); auto predecessor_mask = ConvertElementType(Lt(iota, j), type); auto mask = Broadcast(ConvertElementType(Eq(iota, j), type), std::vector(batch_dims.size(), 1)); + auto successor_mask = Gt(Iota(a.builder(), S32, m), j); auto new_x = Mul(x, predecessor_mask, /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); + new_x = Add( + new_x, Select(Broadcast(successor_mask, batch_dims), v, ZerosLike(v)), + /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {minor_dim})); // Update a[:,j] std::vector dim_ids(num_dims); std::iota(dim_ids.begin(), dim_ids.end(), 0); new_x = BroadcastInDim(new_x, ConcatVectors(batch_dims, {m, n}), /*broadcast_dimensions=*/dim_ids); - const int64 minor_dim = batch_dims.size(); - auto iota_mn = Iota( - builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {m, n})), - minor_dim + 1); a = Select(Eq(iota_mn, j), new_x, a); - // vs[:, j] = v - std::vector vs_broadcast_dims(batch_dims.size() + 1); - std::iota(vs_broadcast_dims.begin(), vs_broadcast_dims.end(), 0); - auto vs_zeros = ZerosLike(vs); - auto vs_update = Select( - Eq(iota_mn, j), - Add(vs_zeros, v, /*broadcast_dimensions=*/vs_broadcast_dims), vs_zeros); - vs = vs + vs_update; - // taus[j] = tau std::vector tau_broadcast_dims(batch_dims.size()); std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0); @@ -240,40 +230,38 @@ StatusOr QRBlock(XlaOp a, PrecisionConfig::Precision precision) { Add(taus_zeros, tau, /*broadcast_dimensions=*/tau_broadcast_dims), taus_zeros); taus = taus + taus_update; - return std::vector{a, vs, taus}; + return std::vector{a, taus}; }; - auto vs = Zeros( - builder, ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n}))); auto taus = Zeros(builder, ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn, - {a, vs, taus}, "qr", builder)); + {a, taus}, "qr", builder)); QRBlockResult result; - result.r = values[0]; - result.vs = values[1]; - result.taus = values[2]; + result.a = values[0]; + result.taus = values[1]; return result; } -// Computes W and Y such that I-WY is equivalent to the sequence of Householder -// transformations given by vs and taus. -// Golub and van Loan, "Matrix Computations", algorithm 5.1.2. -// Y = np.zeros([m, n]) -// W = np.zeros([m, n]) -// Y[:, 0] = vs[:, 0] -// W[:, 0] = -taus[0] * vs[:, 0] -// for j in xrange(1, n): -// v = vs[:, j] -// z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v)) -// W[:, j] = z -// Y[:, j] = v -// return W -// There is no need to return Y since at termination of the loop it is equal to -// vs. -StatusOr ComputeWYRepresentation(PrimitiveType type, +// Computes T such that (I - Y @ T @ Y^t) is a product of the elementary +// Householder reflectors given by `vs` and `taus`. +// +// Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY +// representation for products of Householder transformations." SIAM Journal on +// Scientific and Statistical Computing 10.1 (1989): 53-57. +// +// def compact_wy(vs, taus): +// m, n = vs.shape[-2:] +// t = np.eye(n) * -taus +// # We premultiply Y.T @ vs, since we would prefer to compute a single matrix +// # multiplication to many matrix-vector products. 
+// vtv = -taus[None, :] * np.triu(vs.T @ vs, 1) + np.eye(n) +// for i in range(1, n): +// t[:, i] = np.dot(t, vtv[:, i]) +// return t +StatusOr CompactWYRepresentation(PrimitiveType type, absl::Span batch_dims, XlaOp vs, XlaOp taus, int64 m, int64 n, PrecisionConfig::Precision precision) { @@ -284,50 +272,38 @@ StatusOr ComputeWYRepresentation(PrimitiveType type, auto body_fn = [&](XlaOp j, absl::Span values, XlaBuilder* builder) -> StatusOr> { // w has shape [..., m, n] - auto w = values[0]; - const auto vs = values[1]; - const auto taus = values[2]; + auto t = values[0]; + const auto vtv = values[1]; // Want j values in range [1, ... n). j = j + ConstantR0(builder, 1); - // vs has shape [..., m, 1] - auto v = DynamicSliceInMinorDims(vs, {j}, {1}); - // beta has shape [..., 1] - auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); - - auto iota_mn = Iota( - builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {m, n})), - n_index); - - // y has shape [..., m, n] - auto y = Select(Ge(iota_mn, j), ZerosLike(vs), vs); // yv has shape [..., n, 1] - auto yv = BatchDot(y, true, v, false, precision); - // wyv has shape [..., m, 1] - auto wyv = BatchDot(w, yv, precision); + auto yv = DynamicSliceInMinorDims(vtv, {j}, {1}); - auto z = Mul( - -beta, v + wyv, - /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); + // wyv has shape [..., n, 1] + auto z = BatchDot(t, yv, precision); - w = DynamicUpdateSliceInMinorDims(w, z, {j}); + t = DynamicUpdateSliceInMinorDims(t, z, {j}); - return std::vector{w, vs, taus}; + return std::vector{t, vtv}; }; XlaBuilder* builder = vs.builder(); - auto w = Zeros(builder, - ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n}))); - auto v = SliceInMinorDims(vs, {0}, {1}); - auto beta = SliceInMinorDims(taus, {0}, {1}); - auto bv = - Mul(-beta, v, - /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); - w = UpdateSliceInMinorDims(w, bv, {0}); - TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(n - 1, S32, body_fn, - {w, vs, taus}, "wy", builder)); + auto tau_scale = BroadcastInDim(-taus, ConcatVectors(batch_dims, {1, n}), + ConcatVectors(batch_dim_indices, {n_index})); + + auto eye = Broadcast(IdentityMatrix(builder, type, n, n), batch_dims); + auto t = eye * tau_scale; + + auto vtv = + BatchDot(vs, /*transpose_x=*/true, vs, /*transpose_y=*/false, precision); + vtv = Select(TriangleMask(vtv, 0), ZerosLike(vtv), vtv) * tau_scale; + vtv = vtv + eye; + + TF_ASSIGN_OR_RETURN( + auto values, ForEachIndex(n - 1, S32, body_fn, {t, vtv}, "wy", builder)); return values[0]; } @@ -340,14 +316,12 @@ StatusOr ComputeWYRepresentation(PrimitiveType type, // q = np.eye(m) // for i in xrange(0, min(m, n), block_size): // k = min(block_size, min(m, n) - s) -// (a, vs, taus) = qr(a[i:, i:i+k]) -// y = vs -// w = ComputeWYRepresentation(vs, taus, m-i, k) -// a[i:, i+r:] += np.dot(y, np.dot(w.T, a[i:, i+k:])) -// q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T)) +// (a, taus) = qr(a[i:, i:i+k]) +// y = np.eye(m, n) + np.tril(a, -1) +// t = CompactWYRepresentation(vs, taus, m-i, k) +// a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:]) +// q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T // return (q, a) -// TODO(phawkins): consider using UT transformations (in the form I - V U V') -// rather than WY transformations. 
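The updated qr pseudocode above packs each Householder vector below the diagonal of `a` (with an implicit unit diagonal) and keeps only `taus`, instead of carrying a separate `vs` buffer. The sketch below spells out that storage scheme for the unblocked case in NumPy, assuming a LAPACK-style house() with v[j] = 1 and H = I - tau * v * v^T (an illustration, not the XLA kernel or its compact-WY blocking):

```python
import numpy as np

def house(x, j):
    # Householder reflector for x[j:]: returns full-length v (zeros above j,
    # v[j] = 1), tau and beta such that (I - tau*v*v^T) @ x = [..., beta, 0, ...].
    v = np.zeros_like(x)
    alpha, sigma = x[j], np.linalg.norm(x[j + 1:])
    if sigma == 0.0:
        return v, 0.0, alpha                      # column is already triangular
    beta = -np.copysign(np.hypot(alpha, sigma), alpha)
    v[j] = 1.0
    v[j + 1:] = x[j + 1:] / (alpha - beta)
    return v, (beta - alpha) / beta, beta

def qr_packed(a):
    a = np.array(a, dtype=float)
    m, n = a.shape
    taus = np.zeros(min(m, n))
    for j in range(min(m, n)):
        v, tau, beta = house(a[:, j], j)
        a[:, j + 1:] -= tau * np.outer(v, v @ a[:, j + 1:])  # masked trailing update
        a[j, j] = beta                                       # form column j explicitly
        a[j + 1:, j] = v[j + 1:]                             # pack reflector below diag
        taus[j] = tau
    return a, taus

# Rebuild Q from the packed reflectors and check Q @ R == A, Q orthonormal.
a0 = np.random.default_rng(0).standard_normal((5, 4))
packed, taus = qr_packed(a0)
r = np.triu(packed)
q = np.eye(5)
for j in reversed(range(4)):                                 # Q = H_0 H_1 ... H_3
    v = np.zeros(5)
    v[j], v[j + 1:] = 1.0, packed[j + 1:, j]
    q -= taus[j] * np.outer(v, v @ q)
print(np.allclose(q @ r, a0), np.allclose(q.T @ q, np.eye(5)))  # True True
```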
StatusOr QRDecomposition( XlaOp a, bool full_matrices, int64 block_size, PrecisionConfig::Precision precision) { @@ -381,27 +355,34 @@ StatusOr QRDecomposition( auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k}); TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block, precision)); + auto y = Add( + IdentityMatrix(builder, type, m - i, k), + Select(TriangleMask(qr_block.a, -1), qr_block.a, ZerosLike(qr_block.a)), + /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}); - a = UpdateSliceInMinorDims(a, qr_block.r, {i, i}); + a = UpdateSliceInMinorDims(a, qr_block.a, {i, i}); - // Compute the I-WY block representation of a product of Householder - // matrices. + // Compute the I + Y @ T @ Y^t block representation of a product of + // Householder matrices. TF_ASSIGN_OR_RETURN( - auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs, - qr_block.taus, m - i, k, precision)); - auto y = qr_block.vs; + auto t, CompactWYRepresentation(type, batch_dims, y, qr_block.taus, + m - i, k, precision)); - // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) + // a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:]) + auto yt = + BatchDot(y, /*transpose_x=*/false, t, /*transpose_y=*/true, precision); auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n}); - auto a_update = BatchDot(w, true, a_panel, false, precision); - a_update = BatchDot(y, a_update, precision); + auto a_update = BatchDot(y, /*transpose_x=*/true, a_panel, + /*transpose_y=*/false, precision); + a_update = BatchDot(yt, a_update, precision); a_panel = a_panel + a_update; a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); - // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) + // q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); - auto q_update = BatchDot(q_panel, w, precision); - q_update = BatchDot(q_update, false, y, true, precision); + auto q_update = BatchDot(q_panel, y, precision); + q_update = BatchDot(q_update, /*transpose_x=*/false, yt, + /*transpose_y=*/true, precision); q_panel = q_panel + q_update; q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } @@ -414,7 +395,7 @@ StatusOr QRDecomposition( a = SliceInMinorDims(a, {0, 0}, {p, n}); } result.q = q; - result.r = a; + result.r = UpperTriangle(a); return result; } diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc index a61f243e126..f1d2e4ddb1c 100644 --- a/tensorflow/compiler/xla/client/lib/qr_test.cc +++ b/tensorflow/compiler/xla/client/lib/qr_test.cc @@ -27,12 +27,15 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/tensor_float_32_utils.h" namespace { using QrTest = xla::ClientLibraryTestBase; XLA_TEST_F(QrTest, Simple) { + // Test fails with TensorFloat-32 enabled + tensorflow::enable_tensor_float_32_execution(false); xla::XlaBuilder builder(TestName()); xla::Array2D a_vals({ @@ -61,6 +64,8 @@ XLA_TEST_F(QrTest, Simple) { } XLA_TEST_F(QrTest, ZeroDiagonal) { + // Test fails with TensorFloat-32 enabled + tensorflow::enable_tensor_float_32_execution(false); xla::XlaBuilder builder(TestName()); xla::Array2D a_vals({ @@ -88,6 +93,8 @@ XLA_TEST_F(QrTest, ZeroDiagonal) { } XLA_TEST_F(QrTest, SimpleBatched) { + // Test fails with TensorFloat-32 enabled + tensorflow::enable_tensor_float_32_execution(false); xla::XlaBuilder builder(TestName()); xla::Array3D a_vals({ diff --git a/tensorflow/compiler/xla/client/lib/quantize.h b/tensorflow/compiler/xla/client/lib/quantize.h index 26dbbd5b00b..320dfcbf062 100644 --- a/tensorflow/compiler/xla/client/lib/quantize.h +++ b/tensorflow/compiler/xla/client/lib/quantize.h @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/bfloat16.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc index 1c0680b883a..58905e4ca6f 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc @@ -228,7 +228,7 @@ StatusOr> WhileLoopFn( auto max_sweeps = ScalarLike(k, max_sweep_updates); auto sweep_update_cond = Gt(max_sweeps, k); - auto norms = ComputeFrobeniusNorms(values[2]).ValueOrDie(); + TF_ASSIGN_OR_RETURN(auto norms, ComputeFrobeniusNorms(values[2])); auto tol = norms.total_norm * values[3]; auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm), xla::ConstantR0(cond_builder, false), @@ -400,7 +400,7 @@ SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter, return result; }; auto shape_with_status = builder->GetShape(a); - if (!shape_with_status.status().ok()) { + if (!shape_with_status.ok()) { return return_error(shape_with_status.status()); } Shape a_shape = shape_with_status.ValueOrDie(); @@ -450,7 +450,7 @@ SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter, S32, // "CyclicJacobi", // builder); - if (!output_with_status.status().ok()) { + if (!output_with_status.ok()) { return return_error(output_with_status.status()); } @@ -460,7 +460,11 @@ SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter, result.v = output[1]; result.w = GetMatrixDiagonal(output[2]); - return SortByEigenvalues(result).ValueOrDie(); + auto result_or = SortByEigenvalues(result); + if (!result_or.ok()) { + return return_error(result_or.status()); + } + return result_or.ValueOrDie(); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc index 8e2e713c45c..10e27285f02 100644 --- a/tensorflow/compiler/xla/client/lib/slicing_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -206,10 +206,12 @@ XLA_TEST_F(SlicingTest, DoubleEmptyIndexSelect) { xla::XlaOp input, index; Literal 
l(ShapeUtil::MakeShape(F32, {0, 1, 2, 0})); Literal i(ShapeUtil::MakeShape(S32, {0})); - auto input_data = - CreateParameterAndTransferLiteral(0, l, "input", &builder, &input); - auto index_data = - CreateParameterAndTransferLiteral(1, i, "index", &builder, &index); + TF_ASSERT_OK_AND_ASSIGN( + auto input_data, + CreateParameterAndTransferLiteral(0, l, "input", &builder, &input)); + TF_ASSERT_OK_AND_ASSIGN( + auto index_data, + CreateParameterAndTransferLiteral(1, i, "index", &builder, &index)); TorchIndexSelect(input, index, 0); ComputeAndCompareLiteral(&builder, l, {input_data.get(), index_data.get()}); } @@ -219,8 +221,9 @@ XLA_TEST_F(SlicingTest, EmptyIndexSelectNonZero) { xla::XlaOp input, index; Literal l(ShapeUtil::MakeShape(F32, {0, 2})); - auto input_data = - CreateParameterAndTransferLiteral(0, l, "input", &builder, &input); + TF_ASSERT_OK_AND_ASSIGN( + auto input_data, + CreateParameterAndTransferLiteral(0, l, "input", &builder, &input)); auto index_data = CreateR1Parameter({0, 0, 0}, 1, "index", &builder, &index); TorchIndexSelect(input, index, 0); diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index 750237c2000..abb0054558f 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -16,6 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -27,6 +30,20 @@ XlaOp TopK(XlaOp input, int64 k) { return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); int last_dim = input_shape.dimensions_size() - 1; + int64 last_dim_size = input_shape.dimensions(last_dim); + // TODO(b/148796364): tune these constants for better performance. + const int64 kPerPartitionSize = 8192; // 2^13 + const int64 kLastDimSizeThreshold = 524288; // 2^19 + const int64 kMinNumPartitions = 8; + const int64 kMinimalK = 1000; + if ((k >= kMinimalK) && (k < kPerPartitionSize) && + (kPerPartitionSize / k > 2) && last_dim_size >= kLastDimSizeThreshold) { + int64 num_partitions = + CeilOfRatio(last_dim_size - k, kPerPartitionSize - k); + if (num_partitions >= kMinNumPartitions) { + return TopKWithPartitions(input, k, num_partitions); + } + } Shape iota_shape = ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions())); @@ -80,30 +97,35 @@ XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) { } } - XlaOp values, indices; - for (int64 partition = 0; partition < num_partitions; partition++) { - std::vector start_indices(input_shape.dimensions_size(), 0); - std::vector limit_indices(input_dims.begin(), input_dims.end()); - std::vector strides(input_shape.dimensions_size(), 1); - start_indices[last_dim] = partition * per_partition_size; - limit_indices[last_dim] = - std::min((partition + 1) * per_partition_size, last_dim_size); - // Slice value and indices for this partition.. 
- XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides); + auto topk_body_fn = + [&](XlaOp partition, absl::Span values_and_indices, + XlaBuilder* builder) -> StatusOr> { + auto values = values_and_indices[0]; + auto indices = values_and_indices[1]; + auto input = values_and_indices[2]; + auto iota_s32 = values_and_indices[3]; + + // Slice value and indices for this partition. + XlaOp start = Mul(Add(partition, ConstantR0(builder, 1)), + ConstantR0(builder, per_partition_size)); + XlaOp sliced_input = + DynamicSliceInMinorDims(input, {start}, {per_partition_size}); XlaOp sliced_indices = - Slice(iota_s32, start_indices, limit_indices, strides); + DynamicSliceInMinorDims(iota_s32, {start}, {per_partition_size}); // Concat with previous results. - if (partition > 0) { - sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim); - sliced_indices = - ConcatInDim(builder, {indices, sliced_indices}, last_dim); - } + sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim); + sliced_indices = + ConcatInDim(builder, {indices, sliced_indices}, last_dim); // Sort this slice XlaOp sort_result = Sort({sliced_input, sliced_indices}, CreateScalarGtComputation({input_shape.element_type(), S32}, sliced_indices.builder()), - last_dim, /*is_stable=*/true); + last_dim, true); + + std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector limit_indices(input_dims.begin(), input_dims.end()); + std::vector strides(input_shape.dimensions_size(), 1); // Slice topk. start_indices[last_dim] = 0; limit_indices[last_dim] = k; @@ -111,8 +133,42 @@ XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) { limit_indices, strides); indices = Slice(GetTupleElement(sort_result, 1), start_indices, limit_indices, strides); - } - return Tuple(builder, {values, indices}); + return std::vector{values, indices, input, iota_s32}; + }; + + // Get the values and indices for the first topk so that they can + // be passed to the while loop. + std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector limit_indices(input_dims.begin(), input_dims.end()); + std::vector strides(input_shape.dimensions_size(), 1); + start_indices[last_dim] = 0; + limit_indices[last_dim] = per_partition_size; + // Slice value and indices for the first partition. + XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides); + XlaOp sliced_indices = + Slice(iota_s32, start_indices, limit_indices, strides); + // Sort this slice + XlaOp sort_result = + Sort({sliced_input, sliced_indices}, + CreateScalarGtComputation({input_shape.element_type(), S32}, + sliced_indices.builder()), + last_dim, /*is_stable=*/true); + + // Slice topk. + start_indices[last_dim] = 0; + limit_indices[last_dim] = k; + XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices, + limit_indices, strides); + XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices, + limit_indices, strides); + + // Pass the result of the first TopK to the while loop and do + // num_partition - 1 iterations. 
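TopKWithPartitions now seeds the loop with the top-k of the first partition and then, for each remaining partition, concatenates the carried top-k with the next slice, sorts, and keeps the first k again; given the constants above, TopK only takes this path for large inputs (roughly k >= 1000 and a last dimension of at least 2^19). A NumPy sketch of the carry-and-merge loop along the last axis (it illustrates the algorithm, not ForEachIndex, and assumes k is no larger than the partition size):

```python
import numpy as np

def topk_with_partitions(x, k, num_partitions):
    n = x.shape[-1]
    per_partition = -(-n // num_partitions)               # ceil(n / num_partitions)
    iota = np.broadcast_to(np.arange(n), x.shape)

    def top_k_of(vals, idxs):                             # sort descending, keep k
        order = np.argsort(-vals, axis=-1, kind="stable")[..., :k]
        return (np.take_along_axis(vals, order, -1),
                np.take_along_axis(idxs, order, -1))

    # Top-k of the first partition seeds the carried state.
    values, indices = top_k_of(x[..., :per_partition], iota[..., :per_partition])
    for p in range(1, num_partitions):                    # merge remaining partitions
        lo, hi = p * per_partition, min((p + 1) * per_partition, n)
        values = np.concatenate([values, x[..., lo:hi]], axis=-1)
        indices = np.concatenate([indices, iota[..., lo:hi]], axis=-1)
        values, indices = top_k_of(values, indices)
    return values, indices

x = np.random.default_rng(0).standard_normal((2, 1000))
v, _ = topk_with_partitions(x, k=5, num_partitions=8)
print(np.allclose(v, -np.sort(-x, axis=-1)[..., :5]))     # True
```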
+ TF_ASSIGN_OR_RETURN(auto values_and_indices, + ForEachIndex(num_partitions - 1, S32, topk_body_fn, + {values, indices, input, iota_s32}, + "topk_with_partition", builder)); + return Tuple(builder, {values_and_indices[0], values_and_indices[1]}); }); } diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index e01f6faf59e..e820d5bfe6f 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -118,6 +118,19 @@ XLA_TEST_F(SortingTest, TopK3From8Values5Partitions) { ComputeAndCompareR1(&builder, {7.0, 6.0, 5.0}, {}); } +XLA_TEST_F(SortingTest, DISABLED_TopKLargeInput) { + XlaBuilder builder(TestName()); + Array input({2, 1000000}); + input.FillRandom(1.0f, 2.0f); + auto x = + CreateConstantFromLiteral(LiteralUtil::CreateFromArray(input), &builder); + Array2D expected_array(2, 1000); + expected_array.Fill(2.0f); + xla::GetTupleElement(xla::TopK(x, 1000), 0); + ErrorSpec error_spec(10.0f, 10.0f); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec); +} + XLA_TEST_F(SortingTest, TopK3From8Indices5Partitions) { XlaBuilder builder(TestName()); auto x_rev = diff --git a/tensorflow/compiler/xla/client/lib/svd.cc b/tensorflow/compiler/xla/client/lib/svd.cc index 646875a20a2..80ea4d644c0 100644 --- a/tensorflow/compiler/xla/client/lib/svd.cc +++ b/tensorflow/compiler/xla/client/lib/svd.cc @@ -837,8 +837,11 @@ SVDResult SVD(XlaOp a, int64 max_iter, float epsilon, auto eps = ScalarLike(a, epsilon); - SVDResult svd_result = - HouseHolderBidiagonalization(a, eps, precision).ValueOrDie(); + auto svd_result_or = HouseHolderBidiagonalization(a, eps, precision); + if (!svd_result_or.ok()) { + return return_error(svd_result_or.status()); + } + SVDResult svd_result = svd_result_or.ValueOrDie(); auto output_with_status = WhileLoopFn( { @@ -861,7 +864,13 @@ SVDResult SVD(XlaOp a, int64 max_iter, float epsilon, svd_result.u = output[1]; svd_result.v = output[2]; svd_result.d = output[3]; - svd_result = SortBySingularValuesAndPostProcessing(svd_result).ValueOrDie(); + + svd_result_or = SortBySingularValuesAndPostProcessing(svd_result); + if (!svd_result_or.ok()) { + return return_error(svd_result_or.status()); + } + svd_result = svd_result_or.ValueOrDie(); + if (maybe_transpose) { std::swap(svd_result.u, svd_result.v); } diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 52f61408cbb..3e2a4eb53a7 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -26,12 +26,15 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -39,6 +42,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" namespace xla { @@ -71,6 +76,58 @@ void SetProtoIdAndName(T* entry, const string& base_name, char separator, entry->set_id(id); entry->set_name(GetFullName(base_name, separator, id)); } + +ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) { + return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto(); +} + +void SetInstructionAsConstant(HloInstructionProto* instr, int64 id, + const Shape& shape, bool pred) { + Literal literal = LiteralUtil::CreateR0(pred); + Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie(); + *instr->mutable_shape() = shape.ToProto(); + *instr->mutable_literal() = literal_broadcast.ToProto(); + *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant); +} + +// Converts a HloComputation into ReducerOr with predicate types. +HloComputationProto CreateReduceOr(int64 reducer_id, + HloComputationProto* original_reducer) { + HloComputationProto reducer; + SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id); + std::vector operands_id; + for (auto& inst : original_reducer->instructions()) { + // Copy params. + if (StringToHloOpcode(inst.opcode()).ValueOrDie() == + HloOpcode::kParameter) { + HloInstructionProto* new_param = reducer.add_instructions(); + *new_param = inst; + *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape()); + operands_id.push_back(inst.id()); + } + if (inst.id() == original_reducer->root_id()) { + HloInstructionProto* new_root = reducer.add_instructions(); + *new_root = inst; + *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape()); + *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr); + new_root->clear_operand_ids(); + for (int64 operand_id : operands_id) { + new_root->add_operand_ids(operand_id); + } + reducer.set_root_id(inst.id()); + } + } + return reducer; +} + +bool InstrIsSetBound(const HloInstructionProto* instr_proto) { + HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie(); + if (opcode == HloOpcode::kCustomCall && + instr_proto->custom_call_target() == "SetBound") { + return true; + } + return false; +} } // namespace namespace internal { @@ -247,7 +304,6 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, // GetDimensionSize is always considered constant in XLA -- If a dynamic // dimension is presented, -1 is returned. break; - // Non functional ops. case HloOpcode::kRng: case HloOpcode::kAllReduce: @@ -260,6 +316,11 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, // cannot be constant. We cannot set is_functional=false in other similar // cases since we're already relying on IsConstant to return true. case HloOpcode::kCustomCall: + if (instr.custom_call_target() == "SetBound") { + // Set bound is considered constant -- the bound is used as the value. + break; + } + TF_FALLTHROUGH_INTENDED; case HloOpcode::kWhile: // TODO(b/32495713): We aren't checking the condition and body // computations themselves. 
@@ -446,7 +507,7 @@ StatusOr XlaBuilder::Build(int64 root_id, alias.param_index.ToString().c_str()); } TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number, - alias.param_index)); + alias.param_index, alias.kind)); } *module->mutable_input_output_alias() = config.ToProto(); return Status::OK(); @@ -529,7 +590,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) { XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions, - absl::optional direction) { + absl::optional direction, + absl::optional type) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); @@ -587,7 +649,11 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, return InvalidArgument( "kCompare expects a ComparisonDirection, but none provided."); } - return Compare(shape, updated_lhs, updated_rhs, *direction); + if (type == absl::nullopt) { + return Compare(shape, updated_lhs, updated_rhs, *direction); + } else { + return Compare(shape, updated_lhs, updated_rhs, *direction, *type); + } } if (direction.has_value()) { @@ -610,8 +676,16 @@ XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, StatusOr XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, ComparisonDirection direction) { + return Compare(shape, lhs, rhs, direction, + Comparison::DefaultComparisonType(shape.element_type())); +} + +StatusOr XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, + ComparisonDirection direction, + Comparison::Type type) { HloInstructionProto instr; instr.set_comparison_direction(ComparisonDirectionToString(direction)); + instr.set_comparison_type(ComparisonTypeToString(type)); *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kCompare, {lhs, rhs}); } @@ -1022,6 +1096,36 @@ XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand, }); } +XlaOp XlaBuilder::DynamicReshape(XlaOp operand, + absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic) { + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + std::vector dim_size_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& dim_size_shapes, + GetOperandShapes(dim_sizes)); + + absl::c_transform(dim_size_shapes, std::back_inserter(dim_size_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(const Shape shape, + ShapeInference::InferDynamicReshapeShape( + *operand_shape, dim_size_shape_ptrs, + new_size_bounds, dims_are_dynamic)); + TF_RETURN_IF_ERROR(first_error_); + std::vector operands; + operands.reserve(1 + dim_sizes.size()); + operands.push_back(operand); + for (const XlaOp& dim_size : dim_sizes) { + operands.push_back(dim_size); + } + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kDynamicReshape, + operands); + }); +} + XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { if (dimensions.size() <= 1) { @@ -1364,6 +1468,25 @@ StatusOr XlaBuilder::FftInternal( return AddInstruction(std::move(instr), HloOpcode::kFft, {operand}); } +StatusOr XlaBuilder::TriangularSolveInternal( + const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) { + HloInstructionProto instr; + *instr.mutable_triangular_solve_options() = std::move(options); + 
*instr.mutable_shape() = shape.ToProto(); + + return AddInstruction(std::move(instr), HloOpcode::kTriangularSolve, {a, b}); +} + +StatusOr XlaBuilder::CholeskyInternal(const Shape& shape, XlaOp a, + bool lower) { + HloInstructionProto instr; + xla::CholeskyOptions& options = *instr.mutable_cholesky_options(); + options.set_lower(lower); + *instr.mutable_shape() = shape.ToProto(); + + return AddInstruction(std::move(instr), HloOpcode::kCholesky, {a}); +} + XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1874,7 +1997,6 @@ XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) { XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, const Shape& shape) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state)); Shape output_shape = shape; @@ -1893,14 +2015,22 @@ XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm, return InvalidArgument("Unsupported shape for RngBitGenerator: %s", PrimitiveType_Name(output_shape.element_type())); } - *instr.mutable_shape() = - ShapeUtil::MakeTupleShape({state_shape, output_shape}).ToProto(); - instr.set_rng_algorithm(algorithm); - return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator, - {initial_state}); + return RngBitGeneratorInternal( + ShapeUtil::MakeTupleShape({state_shape, output_shape}), algorithm, + initial_state); }); } +StatusOr XlaBuilder::RngBitGeneratorInternal( + const Shape& full_result_shape, RandomAlgorithm algorithm, + XlaOp initial_state) { + HloInstructionProto instr; + *instr.mutable_shape() = full_result_shape.ToProto(); + instr.set_rng_algorithm(algorithm); + return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator, + {initial_state}); +} + XlaOp XlaBuilder::While(const XlaComputation& condition, const XlaComputation& body, XlaOp init) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -2466,6 +2596,7 @@ XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension, } *(shape.mutable_tuple_shapes(i)->mutable_layout()) = *layout; } + instr.set_constrain_layout(true); } *instr.mutable_shape() = shape.ToProto(); @@ -2842,6 +2973,249 @@ StatusOr XlaBuilder::IsConstant(XlaOp operand) const { return is_constant; } +StatusOr XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { + TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, + LookUpInstruction(root_op)); + + HloComputationProto entry; + SetProtoIdAndName(&entry, StrCat(name_, "_dynamic_inference"), kNameSeparator, + GetNextId()); + ProgramShapeProto* program_shape = entry.mutable_program_shape(); + *program_shape->mutable_result() = + ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto(); + + std::vector called_computatons; + // Process instruction and copy it into the new graph. The new node in the new + // graph with have id set to `id`. + auto process_instruction = [&](const HloInstructionProto* instr_proto, + bool need_rewrite, int64 id, + absl::Span operand_ids) { + // Rewrite the instruction with following rules: + // - Unary ops: Convert into bitcast (identity) with type Pred. + // - Binary ops: Convert into binary or. + // - Select: Convert into binary or with its two data operands. + // - Concat / Tuple/ GTE / Bitcast: Copy. + // - Param: Convert to constant True. 
+ // - GetDimensionSize: Convert to constant True if dimension is dynamic, + // contant False if dimension is static. + // - Reduce: Convert to reduce or. + // - Constant: Convert to constant False. + // - Other ops: Not supported. + // Create the instruction for the new handle. + TF_ASSIGN_OR_RETURN(HloOpcode opcode, + StringToHloOpcode(instr_proto->opcode())); + auto* new_instr = entry.add_instructions(); + *new_instr = *instr_proto; + new_instr->set_id(id); + new_instr->mutable_operand_ids()->Clear(); + for (auto operand_id : operand_ids) { + new_instr->mutable_operand_ids()->Add(operand_id); + } + + if (!need_rewrite) { + *new_instr->mutable_name() = + GetFullName(instr_proto->opcode(), kNameSeparator, id); + return Status::OK(); + } + *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape()); + Shape new_shape(new_instr->shape()); + switch (opcode) { + case HloOpcode::kAbs: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kBitcast: + case HloOpcode::kCeil: + case HloOpcode::kCollectivePermuteDone: + case HloOpcode::kCos: + case HloOpcode::kClz: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFloor: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kNot: + case HloOpcode::kNegate: + case HloOpcode::kPopulationCount: + case HloOpcode::kReal: + case HloOpcode::kRsqrt: + case HloOpcode::kLogistic: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kConvert: + case HloOpcode::kSqrt: + case HloOpcode::kCbrt: + case HloOpcode::kTanh: + CHECK_EQ(instr_proto->operand_ids_size(), 1); + *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kBitcast); + break; + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kDivide: + case HloOpcode::kComplex: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSubtract: + case HloOpcode::kCompare: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kXor: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + CHECK_EQ(instr_proto->operand_ids_size(), 2); + *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr); + break; + case HloOpcode::kSelect: + break; + case HloOpcode::kGather: + break; + case HloOpcode::kReduce: { + int64 reducer_id = new_instr->called_computation_ids(0); + called_computatons.push_back( + CreateReduceOr(reducer_id, &embedded_[reducer_id])); + break; + } + case HloOpcode::kTuple: + case HloOpcode::kTranspose: + case HloOpcode::kGetTupleElement: + case HloOpcode::kSlice: + case HloOpcode::kBroadcast: + case HloOpcode::kConcatenate: + case HloOpcode::kReshape: + break; + case HloOpcode::kGetDimensionSize: { + int64 dimension = instr_proto->dimensions(0); + int64 operand_handle = instr_proto->operand_ids(0); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, + LookUpInstructionByHandle(operand_handle)); + + SetInstructionAsConstant( + new_instr, id, new_shape, + operand_proto->shape().is_dynamic_dimension(dimension)); + break; + } + case HloOpcode::kConstant: + SetInstructionAsConstant(new_instr, id, new_shape, false); + break; + case HloOpcode::kCustomCall: + if (instr_proto->custom_call_target() == "SetBound") { + SetInstructionAsConstant(new_instr, id, new_shape, true); + break; + } else { + return InvalidArgument( + "Dynamic inferencing on custom call %s is not supported", + 
instr_proto->DebugString()); + } + case HloOpcode::kParameter: + SetInstructionAsConstant(new_instr, id, new_shape, true); + break; + default: + return InvalidArgument("Dynamic inferencing %s is not supported", + instr_proto->DebugString()); + } + *new_instr->mutable_name() = + GetFullName(instr_proto->opcode(), kNameSeparator, id); + return Status::OK(); + }; + + struct WorkItem { + explicit WorkItem(int64 handle, bool need_rewrite) + : handle(handle), need_rewrite(need_rewrite), visited(false) {} + int64 handle; + // If need_rewrite is true, the instruction will be copied and rewrite into + // a pred instruction indicating if each value is dynamic. If need_rewrite + // is false, simply copy the instruction to the output graph. + // E.g., + // For select(P, A, B), we need to rewrite A and B into predicates, but + // don't need to rewrite P. + bool need_rewrite; + // Used in dfs to remember the ids of processed operands of this item. + std::vector processed_operands; + // Whether this node been visited before or not. + bool visited; + }; + // Only copy each pair of {handle, need_rewrite} once. Value is the id in the + // new graph. + absl::flat_hash_map, int64> seen; + // Monotonically increasing id to assign to new instructions. + int64 global_id = 0; + // The result id of the last rewritten item -- return value of last stack + // item. + int64 stacktop_id = -1; + std::vector worklist; + worklist.push_back(WorkItem(root->id(), true)); + while (!worklist.empty()) { + WorkItem& item = worklist.back(); + auto item_key = std::make_pair(item.handle, item.need_rewrite); + auto iter = seen.find(item_key); + // Already processed this item. Return previous results. + if (iter != seen.end()) { + stacktop_id = iter->second; + worklist.pop_back(); + continue; + } + + int64 next_operand = item.processed_operands.size(); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto, + LookUpInstructionByHandle(item.handle)); + VLOG(3) << "Processing" << instr_proto->name(); + if (!item.visited) { + item.visited = true; + } else { + // Record previous processed operand. + item.processed_operands.push_back(stacktop_id); + next_operand++; + } + TF_ASSIGN_OR_RETURN(HloOpcode opcode, + StringToHloOpcode(instr_proto->opcode())); + if (next_operand >= instr_proto->operand_ids_size() || + opcode == HloOpcode::kGetDimensionSize || + InstrIsSetBound(instr_proto)) { + // No more operands to process, process self. + int64 new_id = ++global_id; + VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name(); + TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite, + new_id, item.processed_operands)); + stacktop_id = new_id; + seen[item_key] = stacktop_id; + worklist.pop_back(); + continue; + } + + WorkItem next_item(instr_proto->operand_ids(next_operand), true); + if (opcode == HloOpcode::kSelect && next_operand == 0) { + next_item.need_rewrite = false; + } + if (opcode == HloOpcode::kGather && next_operand == 1) { + next_item.need_rewrite = false; + } + // Push next operand into worklist. 
+ worklist.push_back(next_item); + } + TF_RET_CHECK(stacktop_id != -1); + entry.set_root_id(stacktop_id); + absl::c_sort(*entry.mutable_instructions(), + [](const HloInstructionProto& p1, + const HloInstructionProto& p2) { return p1.id() < p2.id(); }); + XlaComputation computation(entry.id()); + HloModuleProto* module = computation.mutable_proto(); + module->set_name(entry.name()); + module->set_id(entry.id()); + module->set_entry_computation_name(entry.name()); + module->set_entry_computation_id(entry.id()); + *module->mutable_host_program_shape() = *program_shape; + for (auto& called_comp : called_computatons) { + *module->add_computations() = called_comp; + } + *module->add_computations() = std::move(entry); + XLA_VLOG_LINES(3, module->DebugString()); + return std::move(computation); +} + StatusOr XlaBuilder::BuildConstantSubGraph( XlaOp root_op, bool dynamic_dimension_is_minus_one) { TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op)); @@ -2886,26 +3260,33 @@ StatusOr XlaBuilder::BuildConstantSubGraph( LookUpInstructionByHandle(handle)); if (instr_proto->opcode() == - HloOpcodeString(HloOpcode::kGetDimensionSize)) { - // At this point, BuildConstantSubGraph should never encounter a - // GetDimensionSize with a dynamic dimension. IsConstant check would have - // failed at the beginning of this function. - // - // Replace GetDimensionSize with a Constant representing the static bound - // of the shape. - int64 dimension = instr_proto->dimensions(0); - int64 operand_handle = instr_proto->operand_ids(0); - TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, - LookUpInstructionByHandle(operand_handle)); + HloOpcodeString(HloOpcode::kGetDimensionSize) || + InstrIsSetBound(instr_proto)) { + int32 constant_value = -1; + if (instr_proto->opcode() == + HloOpcodeString(HloOpcode::kGetDimensionSize)) { + // At this point, BuildConstantSubGraph should never encounter a + // GetDimensionSize with a dynamic dimension. IsConstant check would + // have failed at the beginning of this function. + // + // Replace GetDimensionSize with a Constant representing the static + // bound of the shape. 
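Reviewer note: `BuildDynamicInferenceGraph` above walks the instruction graph without recursion: a stack of `WorkItem`s, a `seen` map so each instruction is copied once, and a `stacktop_id` that hands the child's result back to its parent. The sketch below reproduces just that traversal shape over a small DAG of integers; it drops the `need_rewrite` half of the key and all HLO details, and every name is illustrative.

```cpp
// Iterative DFS with an explicit worklist, shaped like the traversal in
// BuildDynamicInferenceGraph: every node is processed only after all of its
// operands, results are memoized in `seen`, and `stacktop` carries the most
// recently finished node's id back to its parent.
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

struct Node {
  std::vector<int> operands;  // indices of operand nodes
};

// Assigns ids 1..N so that every node gets a larger id than its operands.
std::unordered_map<int, int64_t> AssignPostOrderIds(
    const std::vector<Node>& graph, int root) {
  struct WorkItem {
    int handle;
    bool visited = false;
    std::vector<int64_t> processed_operands;
  };
  std::unordered_map<int, int64_t> seen;  // handle -> assigned id
  std::vector<WorkItem> worklist;
  worklist.push_back({root});
  int64_t next_id = 0;
  int64_t stacktop = -1;
  while (!worklist.empty()) {
    WorkItem& item = worklist.back();
    auto it = seen.find(item.handle);
    if (it != seen.end()) {  // shared operand: reuse the memoized result
      stacktop = it->second;
      worklist.pop_back();
      continue;
    }
    size_t next_operand = item.processed_operands.size();
    if (!item.visited) {
      item.visited = true;
    } else {
      // We just returned from a child; record its id.
      item.processed_operands.push_back(stacktop);
      ++next_operand;
    }
    const std::vector<int>& operands = graph[item.handle].operands;
    if (next_operand >= operands.size()) {
      // All operands done: "process" this node by assigning its id.
      seen[item.handle] = ++next_id;
      stacktop = seen[item.handle];
      worklist.pop_back();
      continue;
    }
    worklist.push_back({operands[next_operand]});
  }
  return seen;
}

int main() {
  // Nodes 0 and 1 feed node 2; nodes 1 and 2 feed node 3 (node 1 is shared).
  std::vector<Node> graph = {{{}}, {{}}, {{0, 1}}, {{1, 2}}};
  for (const auto& [handle, id] : AssignPostOrderIds(graph, /*root=*/3)) {
    std::cout << "node " << handle << " -> id " << id << "\n";
  }
}
```

The memoization is what keeps shared operands from being duplicated, which is also why the real code can afford to sort the emitted instructions by id at the end.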
+ int64 dimension = instr_proto->dimensions(0); + int64 operand_handle = instr_proto->operand_ids(0); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, + LookUpInstructionByHandle(operand_handle)); - int32 constant_dimension_size = -1; - if (!(operand_proto->shape().is_dynamic_dimension(dimension) && - dynamic_dimension_is_minus_one)) { - constant_dimension_size = - static_cast(operand_proto->shape().dimensions(dimension)); + if (!(operand_proto->shape().is_dynamic_dimension(dimension) && + dynamic_dimension_is_minus_one)) { + constant_value = + static_cast(operand_proto->shape().dimensions(dimension)); + } + } else { + TF_RET_CHECK( + absl::SimpleAtoi(instr_proto->backend_config(), &constant_value)); } - Literal literal = LiteralUtil::CreateR0(constant_dimension_size); + Literal literal = LiteralUtil::CreateR0(constant_value); HloInstructionProto const_instr; *const_instr.mutable_shape() = literal.shape().ToProto(); @@ -2937,6 +3318,9 @@ StatusOr XlaBuilder::BuildConstantSubGraph( if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) { continue; } + if (InstrIsSetBound(instr_src)) { + continue; + } auto* instr = entry.add_instructions(); *instr = *instr_src; @@ -3215,6 +3599,13 @@ XlaOp Reshape(const Shape& shape, XlaOp operand) { return operand.builder()->Reshape(shape, operand); } +XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic) { + return operand.builder()->DynamicReshape(operand, dim_sizes, new_size_bounds, + dims_are_dynamic); +} + XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, int64 inferred_dimension) { @@ -3274,31 +3665,71 @@ XlaOp Eq(const XlaOp lhs, const XlaOp rhs, return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq); } +XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions) { + auto compare_type = Comparison::Type::kFloatTotalOrder; + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq, + compare_type); +} + XlaOp Ne(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe); } +XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions) { + auto compare_type = Comparison::Type::kFloatTotalOrder; + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe, + compare_type); +} + XlaOp Ge(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe); } +XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions) { + auto compare_type = Comparison::Type::kFloatTotalOrder; + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe, + compare_type); +} + XlaOp Gt(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt); } +XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions) { + auto compare_type = Comparison::Type::kFloatTotalOrder; + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt, + compare_type); +} + XlaOp Le(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe); } +XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions) { + auto compare_type = 
Comparison::Type::kFloatTotalOrder; + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe, + compare_type); +} XlaOp Lt(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt); } +XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions) { + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt, + Comparison::Type::kFloatTotalOrder); +} + XlaOp Compare(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions, ComparisonDirection direction) { @@ -3306,6 +3737,13 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs, broadcast_dimensions, direction); } +XlaOp Compare(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction, Comparison::Type compare_type) { + return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs, + broadcast_dimensions, direction, compare_type); +} + XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) { return Compare(lhs, rhs, {}, direction); } @@ -3386,36 +3824,26 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, TriangularSolveOptions::Transpose transpose_a) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a)); TF_ASSIGN_OR_RETURN(const Shape* b_shape, builder->GetShapePtr(b)); - xla::TriangularSolveOptions& options = - *instr.mutable_triangular_solve_options(); + xla::TriangularSolveOptions options; options.set_left_side(left_side); options.set_lower(lower); options.set_unit_diagonal(unit_diagonal); options.set_transpose_a(transpose_a); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape( *a_shape, *b_shape, options)); - *instr.mutable_shape() = shape.ToProto(); - - return builder->AddInstruction(std::move(instr), - HloOpcode::kTriangularSolve, {a, b}); + return builder->TriangularSolveInternal(shape, a, b, std::move(options)); }); } XlaOp Cholesky(XlaOp a, bool lower) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a)); - xla::CholeskyOptions& options = *instr.mutable_cholesky_options(); - options.set_lower(lower); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCholeskyShape(*a_shape)); - *instr.mutable_shape() = shape.ToProto(); - - return builder->AddInstruction(std::move(instr), HloOpcode::kCholesky, {a}); + return builder->CholeskyInternal(shape, a, lower); }); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 1960d0c4632..cd9809c2a20 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -163,6 +164,15 @@ class XlaBuilder { // OpMetadata attached until a call to ClearOpMetadata. 
void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); } + // Swaps the passed op metadata with the ones currently set. + // + // Returns the old op metadata. + OpMetadata SwapOpMetadata(OpMetadata metadata) { + OpMetadata old_metadata = std::move(metadata_); + metadata_ = std::move(metadata); + return old_metadata; + } + // Similar to SetOpMetadata, but only set the metadata for the next op. void SetOneShotOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); @@ -277,6 +287,31 @@ class XlaBuilder { StatusOr BuildConstantSubGraph( XlaOp root_op, bool dynamic_dimension_is_uint_max = false); + // Similar to BuildConstantSubGraph, but with root element type changed to + // boolean. A true value in the root indicates that the value is dynamic while + // false value indicates that the value is a constant. This will copy the + // needed ops/computations to the subgraph. + // + // E.g., + // Compuptation { + // a = 3 + // b = param(0) + // ROOT Tuple(a + b, a + 1, b + 1) + // } + // Calling BuildDynamicInferenceGraph on root will produce the following + // graph: + // + // Compuptation { + // a = False + // b = True + // ROOT Tuple(a | b, a, b) + // } + // + // The result, which is (True, False, True) after evaluation, can be + // interpreted as "First element is dynamic; Second element is static; Third + // element is dynamic". + StatusOr BuildDynamicInferenceGraph(XlaOp root_op); + // Returns the first error that was encountered while building the // computation. When an error is encountered, by default we return a vacuous // XlaOp and inform the user of the error that occurred while @@ -340,6 +375,7 @@ class XlaBuilder { // // TODO(b/119520625): Remove this API once we have more dynamic shape infra // ready. + ABSL_DEPRECATED("Use SetDimensionSize to set a dynamic dimension.") Status SetDynamicBinding(int64 dynamic_size_param_num, ShapeIndex dynamic_size_param_index, int64 target_param_num, @@ -349,12 +385,16 @@ class XlaBuilder { // not available until the computation is built, and eventual error in the // arguments of this API will be detected only at computation Build() time. // - // Note: Aliasing API is 'may-alias' and only donated buffer at runtime will - // be aliased with output. If a buffer is not donated at runtime, a copy will - // be inserted by XLA to prevent buffer clobbering. + // Note: Except when 'must-alias' is true, alias is assumed to be 'may-alias' + // and only donated buffer at runtime will be aliased with output. If a buffer + // is not donated at runtime, a copy will be inserted by XLA to prevent buffer + // clobbering. void SetUpAlias(const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { - input_output_aliases_.push_back({output_index, param_number, param_index}); + const ShapeIndex& param_index, + HloInputOutputAliasConfig::AliasKind kind = + HloInputOutputAliasConfig::AliasKind::kMayAlias) { + input_output_aliases_.push_back( + {output_index, param_number, param_index, kind}); } // Describes an input/output alias as inserted by the SetUpAlias() API. @@ -365,6 +405,8 @@ class XlaBuilder { int64 param_number; // Specifies the index of the aliased buffer in the parameter ShapeIndex param_index; + // Specifies if the alias is a must alias or may alias. 
+ HloInputOutputAliasConfig::AliasKind kind; }; // Looks up the HloInstruction and sets the frontend attribute "attribute" to @@ -422,6 +464,10 @@ class XlaBuilder { XlaOp Reshape(const Shape& shape, XlaOp operand, int64 inferred_dimension = -1); + XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic); + XlaOp Collapse(XlaOp operand, absl::Span dimensions); XlaOp Slice(XlaOp operand, absl::Span start_indices, @@ -521,6 +567,12 @@ class XlaBuilder { FftType fft_type, absl::Span fft_length); + virtual StatusOr TriangularSolveInternal( + const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options); + + virtual StatusOr CholeskyInternal(const Shape& shape, XlaOp a, + bool lower); + XlaOp Infeed(const Shape& shape, const string& config = ""); XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config); virtual StatusOr InfeedWithTokenInternal( @@ -669,6 +721,11 @@ class XlaBuilder { XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, const Shape& shape); + // Internal variant for the op with the full result shape containing both data + // and state shape as a tuple. + virtual StatusOr RngBitGeneratorInternal( + const Shape& full_result_shape, RandomAlgorithm algorithm, + XlaOp initial_state); XlaOp While(const XlaComputation& condition, const XlaComputation& body, XlaOp init); @@ -741,8 +798,13 @@ class XlaBuilder { XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension); - StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, - absl::Span operands = {}); + virtual StatusOr AddInstruction(HloInstructionProto&& instr, + HloOpcode opcode, + absl::Span operands); + StatusOr AddInstruction(HloInstructionProto&& instr, + HloOpcode opcode) { + return AddInstruction(std::move(instr), opcode, /*operands=*/{}); + } void AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr); @@ -760,14 +822,17 @@ class XlaBuilder { // broadcast_dimensions specifies which dimensions to use for broadcasting // when the operation is between tensors of different ranks. The direction is // only used if opcode is kCompare. - XlaOp BinaryOp( - HloOpcode binop, XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions, - absl::optional direction = absl::nullopt); + XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + absl::optional direction = absl::nullopt, + absl::optional type = absl::nullopt); // Internal helper method for binary op compare without broadcast dimensions. virtual StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, - Comparison::Direction direction); + ComparisonDirection direction); + virtual StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, + ComparisonDirection direction, + Comparison::Type type); // Internal helper method that does the building for an arbitrary binary op // with same ranked operands that doesn't broadcast. 
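Reviewer note on the header changes above: several ops are now split into a public entry point that does validation and shape inference plus a virtual `*Internal` method (`TriangularSolveInternal`, `CholeskyInternal`, `RngBitGeneratorInternal`, and the now-virtual `AddInstruction`) that only emits the instruction, so alternative builders can override emission without re-implementing the checks. A generic sketch of that split follows; the class and method names are made up and this is not the `XlaBuilder` API.

```cpp
// Sketch of the "public wrapper + overridable Internal hook" split used for
// TriangularSolveInternal, CholeskyInternal, RngBitGeneratorInternal, etc.
#include <iostream>
#include <stdexcept>
#include <string>

class Builder {
 public:
  virtual ~Builder() = default;

  // Public entry point: validates and infers, then delegates emission.
  std::string Cholesky(int rank, bool lower) {
    if (rank <= 0) throw std::invalid_argument("rank must be positive");
    const std::string shape = "f32[" + std::to_string(rank) + "," +
                              std::to_string(rank) + "]";
    return CholeskyInternal(shape, lower);
  }

 protected:
  // Overridable hook: only responsible for "emitting" the operation.
  virtual std::string CholeskyInternal(const std::string& shape, bool lower) {
    return "cholesky(" + shape + ", lower=" + (lower ? "true" : "false") + ")";
  }
};

// A derived builder changes how instructions are recorded, not the checks.
class LoggingBuilder : public Builder {
 protected:
  std::string CholeskyInternal(const std::string& shape, bool lower) override {
    std::cout << "emitting cholesky on " << shape << "\n";
    return Builder::CholeskyInternal(shape, lower);
  }
};

int main() {
  LoggingBuilder builder;
  std::cout << builder.Cholesky(/*rank=*/4, /*lower=*/true) << "\n";
}
```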
@@ -905,6 +970,10 @@ class XlaBuilder { friend XlaOp Reshape(const Shape& shape, XlaOp operand); + friend XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic); + friend XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, int64 inferred_dimension); @@ -933,22 +1002,13 @@ class XlaBuilder { friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); friend XlaOp Tuple(XlaBuilder* builder, absl::Span elements); friend XlaOp GetTupleElement(XlaOp tuple_data, int64 index); - friend XlaOp Eq(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Ne(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Ge(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Gt(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Lt(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Le(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); friend XlaOp Compare(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions, ComparisonDirection direction); - friend XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction); + friend XlaOp Compare(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction, + Comparison::Type compare_type); friend XlaOp Dot(XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config); friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, @@ -1288,6 +1348,25 @@ class XlaScopedFrontendAttributesAssignment { TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment); }; + +// RAII-style object: sets the current op metadata in builder on construction, +// and sets back to the previous assignment on destruction. +class XlaScopedOpMetadataAssignment { + public: + XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata) + : builder_(builder) { + saved_ = builder_->SwapOpMetadata(metadata); + } + + ~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); } + + private: + xla::XlaBuilder* const builder_; + OpMetadata saved_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedOpMetadataAssignment); +}; + // Free functions for building XlaOps. The intention is that these will // become the public API for building XlaOps rather than calling methods on // XlaBuilder directly. @@ -1427,9 +1506,16 @@ XlaOp Pad(XlaOp operand, XlaOp padding_value, XlaOp Reshape(XlaOp operand, absl::Span dimensions, absl::Span new_sizes); -// Enqueues an operation onto the computation that collapses the operand, from -// first to last dimension (C order), then reshapes it to the given dimension -// sizes. Conceptually, this is a limited form of "shape casting". +// Enqueues a dynamic reshape operation. The dynamic reshape takes additional +// XlaOps as sizes for the result dimension. The result dim i is a dynamic +// dimension dimension if dims_are_dynamic[i] is true. +XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic); + +// Enqueues an operation onto the computation that collapses the operand, +// from first to last dimension (C order), then reshapes it to the given +// dimension sizes. Conceptually, this is a limited form of "shape casting". XlaOp Reshape(XlaOp operand, absl::Span new_sizes); // Enqueues a Reshape op that uses an explicit target shape. 
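Reviewer note: `XlaScopedOpMetadataAssignment` above follows the usual RAII save-and-restore shape built on the new `SwapOpMetadata`: swap the new value in on construction, swap the old one back on destruction. A generic standalone version of the same idea, where `ScopedSwap` is an illustrative name rather than anything in the codebase:

```cpp
// Generic save-and-restore guard in the spirit of
// XlaScopedOpMetadataAssignment: the previous value is reinstated when the
// guard goes out of scope, even on early return or exception.
#include <iostream>
#include <string>
#include <utility>

template <typename T>
class ScopedSwap {
 public:
  ScopedSwap(T* slot, T new_value)
      : slot_(slot), saved_(std::exchange(*slot, std::move(new_value))) {}
  ~ScopedSwap() { *slot_ = std::move(saved_); }

  ScopedSwap(const ScopedSwap&) = delete;
  ScopedSwap& operator=(const ScopedSwap&) = delete;

 private:
  T* slot_;
  T saved_;
};

int main() {
  std::string op_metadata = "outer";
  {
    ScopedSwap<std::string> guard(&op_metadata, "inner");
    std::cout << op_metadata << "\n";  // inner
  }
  std::cout << op_metadata << "\n";  // outer (restored by the destructor)
}
```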
@@ -1542,29 +1628,44 @@ XlaOp GetTupleElement(XlaOp tuple_data, int64 index); // Enqueues an equal-to comparison instruction onto the computation. XlaOp Eq(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); +XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); // Enqueues a not-equal comparison instruction onto the computation. XlaOp Ne(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); +XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); // Enqueues a greater-or-equal comparison instruction onto the computation. XlaOp Ge(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); +XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); // Enqueues a greater-than comparison instruction onto the computation. XlaOp Gt(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); +XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); // Enqueues a less-than comparison instruction onto the computation. XlaOp Lt(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); +XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); // Enqueues a less-or-equal comparison instruction onto the computation. XlaOp Le(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); +XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); // Enqueues a comparison instruction onto the computation (optionally without // broadcast_dimensions for consistency with others). +XlaOp Compare(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction, Comparison::Type compare_type); XlaOp Compare(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions, ComparisonDirection direction); diff --git a/tensorflow/compiler/xla/comparison_util.cc b/tensorflow/compiler/xla/comparison_util.cc index 47fb69e3bce..06dd9642cac 100644 --- a/tensorflow/compiler/xla/comparison_util.cc +++ b/tensorflow/compiler/xla/comparison_util.cc @@ -54,32 +54,59 @@ StatusOr StringToComparisonDirection( return it->second; } -Comparison::Comparison(Direction dir, PrimitiveType type) : dir_(dir) { +StatusOr StringToComparisonType( + absl::string_view compare_type_name) { + static auto* type_map = new absl::flat_hash_map({ + {"FLOAT", Comparison::Type::kFloat}, + {"TOTALORDER", Comparison::Type::kFloatTotalOrder}, + {"SIGNED", Comparison::Type::kSigned}, + {"UNSIGNED", Comparison::Type::kUnsigned}, + }); + auto it = type_map->find(compare_type_name); + if (it == type_map->end()) { + return InvalidArgument("Unknown comparison type: %s", compare_type_name); + } + return it->second; +} + +std::string ComparisonTypeToString(Comparison::Type type) { + switch (type) { + case Comparison::Type::kFloat: + return "FLOAT"; + case Comparison::Type::kFloatTotalOrder: + return "TOTALORDER"; + case Comparison::Type::kSigned: + return "SIGNED"; + case Comparison::Type::kUnsigned: + return "UNSIGNED"; + } +} + +Comparison::Comparison(Direction dir, PrimitiveType type) + : dir_(dir), type_(DefaultComparisonType(type)) {} + +Comparison::Type Comparison::DefaultComparisonType(PrimitiveType type) { switch (type) { case S8: case S16: case S32: case S64: - type_ = Type::kSigned; - break; + return Type::kSigned; case PRED: case U8: case U16: case U32: case U64: - type_ = Type::kUnsigned; - break; + return Type::kUnsigned; case F16: case F32: case BF16: case F64: case C64: case C128: - type_ = Type::kFloat; - break; + return Type::kFloat; default: LOG(FATAL) << 
"Unsupported comparison mode." - << ComparisonDirectionToString(dir) << ":" << PrimitiveType_Name(type) << "\n"; } } @@ -164,20 +191,6 @@ bool Comparison::IsAntireflexive() const { } } -/* static */ const char* Comparison::ComparisonTypeToString( - Comparison::Type type) { - switch (type) { - case Type::kFloat: - return "f"; - case Type::kFloatTotalOrder: - return "ft"; - case Type::kSigned: - return "s"; - case Type::kUnsigned: - return "u"; - } -} - std::string Comparison::ToString(std::string prefix1, std::string prefix2) const { return prefix1 + std::string(ComparisonDirectionToString(dir_)) + prefix2 + diff --git a/tensorflow/compiler/xla/comparison_util.h b/tensorflow/compiler/xla/comparison_util.h index 11335c6b5ba..33ae2c67106 100644 --- a/tensorflow/compiler/xla/comparison_util.h +++ b/tensorflow/compiler/xla/comparison_util.h @@ -103,11 +103,11 @@ class Comparison { bool Compare(const T a, const T b) const { return GetComparator()(a, b); } + static Type DefaultComparisonType(PrimitiveType t); private: static Direction Converse(Direction dir); static Direction Inverse(Direction dir); - static const char* ComparisonTypeToString(Type type); const Direction dir_; Type type_; @@ -117,10 +117,14 @@ inline std::ostream& operator<<(std::ostream& os, const Comparison& cmp) { return os << cmp.ToString(); } string ComparisonDirectionToString(Comparison::Direction direction); +std::string ComparisonTypeToString(Comparison::Type type); StatusOr StringToComparisonDirection( absl::string_view direction_name); +StatusOr StringToComparisonType( + absl::string_view compare_type_name); + using ComparisonDirection = Comparison::Direction; } // namespace xla diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index 16563bab5bc..a926e8b3c88 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -89,6 +89,32 @@ class Sharding(object): tile_assignment_dimensions=dims, tile_assignment_devices=list(flattened_devices))) + @classmethod + def partial_tile(cls, tile_assignment): + """Returns a partially tiled sharding attribute. + + This is similar to tile(), but tile_assignment has one more dimension than + the tensor, and tiles in the last dimension of tile_assignment are + replicated. + + Args: + tile_assignment: An np.ndarray describing the topology of the tiling and + which device will compute which part of the topology. + + Raises: + TypeError: tile_assignment was not of np.array type. + """ + if not isinstance(tile_assignment, _np.ndarray): + raise TypeError('PartialTile assignment must be of type np.ndarray') + dims = list(tile_assignment.shape) + flattened_devices = tile_assignment.reshape(-1, order='C') + return Sharding( + proto=xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.OTHER, + tile_assignment_dimensions=dims, + tile_assignment_devices=list(flattened_devices), + replicate_on_last_tile_dim=True)) + @classmethod def split(cls, tensor, split_dimension, num_devices, input_shape=None): """Returns a Sharding that splits a tensor across a dimension. @@ -245,6 +271,23 @@ def split(tensor, return tensor +def partial_tile(tensor, tile_assignment, use_sharding_op=False): + """Returns a tensor that has tiled sharding. + + Args: + tensor: A tf.Tensor to shard. + tile_assignment: An np.ndarray describing the topology of the tiling and + which device will compute which part of the topology. 
It must have one + more dimension than tensor, and the last dimension represents partially + replicated tiles. + use_sharding_op: If true, adds a sharding op to set the sharding. + """ + if use_sharding_op: + tensor = tf2xla.sharding(tensor) + Sharding.partial_tile(tile_assignment).apply_to_tensor(tensor) + return tensor + + def get_op_sharding(op): """Returns sharding attribute of an op. @@ -313,20 +356,30 @@ def mesh_split(tensor, use_sharding_op: If true, adds a sharding op to set the sharding. Raises: - ValueError: The number of tensor split dimensions is different from device - mesh rank. + ValueError: The number of tensor split dimensions is larger than device mesh + rank. """ permutation = [d for d in tensor_split_dims_mapping if d >= 0] - if len(permutation) != len(device_mesh.shape): + if len(permutation) > len(device_mesh.shape): raise ValueError( - 'Number of tensor split dimensions (%r) is different from device mesh ' + 'Number of tensor split dimensions (%r) is larger than device mesh ' 'rank (%r). tensor_split_dims_mapping: %r, device_mesh.shape: %r' % (len(permutation), len( device_mesh.shape), tensor_split_dims_mapping, device_mesh.shape)) - tile_assignment = _np.transpose(device_mesh, permutation) + # Append replicated dimensions to the end. + transpose_permutation = permutation + [ + d for d in range(len(device_mesh.shape)) if d not in permutation + ] + tile_assignment = _np.transpose(device_mesh, transpose_permutation) tile_shape = [ 1 if d < 0 else device_mesh.shape[d] for d in tensor_split_dims_mapping ] + partial = len(permutation) < len(device_mesh.shape) + if partial: + tile_shape.append(_np.prod(device_mesh.shape) // _np.prod(tile_shape)) tile_assignment = _np.reshape(tile_assignment, tile_shape) + if partial: + return partial_tile( + tensor, tile_assignment, use_sharding_op=use_sharding_op) return tile(tensor, tile_assignment, use_sharding_op=use_sharding_op) diff --git a/tensorflow/compiler/xla/g3doc/index.md b/tensorflow/compiler/xla/g3doc/index.md index 51d666fba9a..45abd9b4c92 100644 --- a/tensorflow/compiler/xla/g3doc/index.md +++ b/tensorflow/compiler/xla/g3doc/index.md @@ -121,8 +121,8 @@ example. ### AOT (Ahead-of-time) compilation for CPU with `tfcompile` -You can also use a standalone [`tfcompile`](./tfcompile) tool, -which converts TensorFlow graph into executable code (for x86-64 CPU only). +You can also use a standalone [`tfcompile`](./tfcompile.md) tool, which converts +TensorFlow graph into executable code (for x86-64 CPU only). ## Inspect compiled programs @@ -196,7 +196,7 @@ Apart from TensorFlow, XLA programs can be generated by: [XLA source](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla) on Github! - diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 3031bfbf2e2..051c1539f6b 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -1235,7 +1235,10 @@ floating-point types. Where `Op` is one of `Eq` (equal-to), `Ne` (not equal-to), `Ge` (greater-or-equal-than), `Gt` (greater-than), `Le` (less-or-equal-than), `Lt` -(less-than). +(less-than). Another set of operators, EqTotalOrder, NeTotalOrder, GeTotalOrder, +GtTotalOrder, LeTotalOrder, and LtTotalOrder, provide the same functionalities, +except that they additionally support a total order over the floating point +numbers, by enforcing -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN. 
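Reviewer note: the operation_semantics.md paragraph above defines the total order -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN for the new TotalOrder comparisons. One common way to realize such an order on IEEE-754 floats (shown here as a sketch, not necessarily how XLA lowers these ops) is to map each value's bit pattern to an unsigned key that is monotone in that order:

```cpp
// Total order over f32 values, including -0 vs +0 and NaNs of either sign:
// map the IEEE-754 bit pattern to an unsigned key that increases along
// -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN.
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <limits>

uint32_t TotalOrderKey(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));  // bit_cast
  // Negative values: flip all bits so larger magnitudes sort earlier.
  // Non-negative values: set the sign bit so they sort after all negatives.
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

bool LtTotalOrder(float a, float b) {
  return TotalOrderKey(a) < TotalOrderKey(b);
}

int main() {
  const float inf = std::numeric_limits<float>::infinity();
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float neg_nan = std::copysign(nan, -1.0f);
  std::cout << std::boolalpha;
  std::cout << LtTotalOrder(-0.0f, 0.0f) << "\n";   // true: -0 < +0
  std::cout << LtTotalOrder(-inf, -1.0f) << "\n";   // true: -Inf < -Finite
  std::cout << LtTotalOrder(1.0f, nan) << "\n";     // true: +Finite < +NaN
  std::cout << LtTotalOrder(neg_nan, -inf) << "\n"; // true: -NaN < -Inf
}
```

Unlike the default floating-point comparison, this ordering makes NaNs comparable and distinguishes -0 from +0, which is exactly what the TotalOrder variants add.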
Arguments | Type | Semantics --------- | ------- | ---------------------------------------- diff --git a/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb index c0160f2766c..d7799093583 100644 --- a/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb +++ b/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb @@ -169,7 +169,7 @@ " model.set_weights(initial_weights)\n", "\n", "warmup(model, x_train, y_train, x_test, y_test)\n", - "%time train_model(model, x_train, y_train, x_test, y_test)\n", + "train_model(model, x_train, y_train, x_test, y_test)\n", "\n", "scores = model.evaluate(x_test, y_test, verbose=1)\n", "print('Test loss:', scores[0])\n", diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 3807e6d3a56..d26e0881c53 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -1004,14 +1004,20 @@ absl::optional LiteralBase::GetIntegralAsS64( switch (shape().element_type()) { case PRED: return Get(multi_index); + case S8: + return Get(multi_index); case U8: return Get(multi_index); + case S16: + return Get(multi_index); + case U16: + return Get(multi_index); case S32: return Get(multi_index); - case S64: - return Get(multi_index); case U32: return Get(multi_index); + case S64: + return Get(multi_index); case U64: return Get(multi_index); default: diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index 6e61e0600a0..54240587282 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -59,6 +59,10 @@ cc_library( name = "tracked_device_buffer", srcs = ["tracked_device_buffer.cc"], hdrs = ["tracked_device_buffer.h"], + visibility = [ + "//learning/pathways/data_parallel:__pkg__", + "//tensorflow:internal", + ], deps = [ ":event_pool", ":local_device_state", @@ -204,6 +208,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core/common_runtime:bfc_allocator", "//tensorflow/core/common_runtime/gpu:gpu_mem_allocator", + "//tensorflow/core:lib_internal", "//tensorflow/stream_executor:tf_allocator_adapter", ] + if_cuda(["@local_config_nccl//:nccl"]), ) diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc index be70c16fc12..e2543bda7df 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc @@ -25,8 +25,8 @@ static const char kCpuPlatformName[] = "cpu"; CpuDevice::CpuDevice(int id, std::unique_ptr local_device_state) - : Device(id, std::move(local_device_state), kCpuPlatformName, - /*device_kind=*/kCpuPlatformName) {} + : PjRtDevice(id, std::move(local_device_state), kCpuPlatformName, + /*device_kind=*/kCpuPlatformName) {} StatusOr> GetCpuClient(bool asynchronous) { TF_ASSIGN_OR_RETURN(se::Platform * platform, @@ -39,7 +39,7 @@ StatusOr> GetCpuClient(bool asynchronous) { TF_ASSIGN_OR_RETURN(LocalClient * client, ClientLibrary::GetOrCreateLocalClient(options)); - std::vector> devices; + std::vector> devices; for (int i = 0; i < client->device_count(); ++i) { se::StreamExecutorConfig config; config.ordinal = i; diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.h b/tensorflow/compiler/xla/pjrt/cpu_device.h index c70d90ae228..ad0079b1c4a 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.h +++ b/tensorflow/compiler/xla/pjrt/cpu_device.h @@ -23,7 +23,7 @@ limitations under the License. 
namespace xla { -class CpuDevice : public Device { +class CpuDevice : public PjRtDevice { public: CpuDevice(int id, std::unique_ptr local_device_state); }; diff --git a/tensorflow/compiler/xla/pjrt/distributed/BUILD b/tensorflow/compiler/xla/pjrt/distributed/BUILD index 5cada95390c..175b4268dda 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/BUILD +++ b/tensorflow/compiler/xla/pjrt/distributed/BUILD @@ -52,6 +52,9 @@ cc_library( tf_cc_test( name = "service_test", srcs = ["service_test.cc"], + tags = [ + "nomsan", # b/163629207 + ], deps = [ ":protocol_proto_cc", ":service", @@ -106,6 +109,9 @@ cc_library( tf_cc_test( name = "client_server_test", srcs = ["client_server_test.cc"], + tags = [ + "nomsan", # b/163629207 + ], deps = [ ":client", ":protocol_proto_cc", diff --git a/tensorflow/compiler/xla/pjrt/distributed/client.cc b/tensorflow/compiler/xla/pjrt/distributed/client.cc index 55b02c6a09e..43c0c7b277d 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/client.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/client.cc @@ -17,6 +17,7 @@ limitations under the License. #include // NOLINT +#include "absl/time/time.h" #include "tensorflow/compiler/xla/pjrt/distributed/protocol.h" #include "tensorflow/compiler/xla/pjrt/distributed/util.h" @@ -36,6 +37,7 @@ xla::Status DistributedRuntimeClient::Connect( ctx.set_deadline(absl::ToChronoTime(absl::Now() + rpc_timeout_)); ConnectRequest request; request.set_protocol_version(kDistributedRuntimeProtocolVersion); + request.set_timeout_milliseconds(absl::ToInt64Milliseconds(rpc_timeout_)); *request.mutable_local_topology() = local_topology; VLOG(10) << "Connect: " << request.DebugString(); ConnectResponse response; diff --git a/tensorflow/compiler/xla/pjrt/distributed/protocol.h b/tensorflow/compiler/xla/pjrt/distributed/protocol.h index 4daa939ac8d..e8be43006f7 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/protocol.h +++ b/tensorflow/compiler/xla/pjrt/distributed/protocol.h @@ -18,7 +18,7 @@ limitations under the License. namespace xla { -static constexpr int kDistributedRuntimeProtocolVersion = 1; +static constexpr int kDistributedRuntimeProtocolVersion = 2; } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/distributed/protocol.proto b/tensorflow/compiler/xla/pjrt/distributed/protocol.proto index 18bfa221110..c3bbb3a7f5d 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/protocol.proto +++ b/tensorflow/compiler/xla/pjrt/distributed/protocol.proto @@ -61,6 +61,7 @@ message ConnectRequest { int32 protocol_version = 1; // Always 1 at present. LocalTopologyProto local_topology = 2; + int32 timeout_milliseconds = 3; } message ConnectResponse { diff --git a/tensorflow/compiler/xla/pjrt/distributed/service.cc b/tensorflow/compiler/xla/pjrt/distributed/service.cc index 3325fcd8319..868529637de 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/service.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/service.cc @@ -15,6 +15,7 @@ limitations under the License. 
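Reviewer note: the protocol change above lets the client put its own RPC deadline into `ConnectRequest.timeout_milliseconds`, and the service-side hunk that follows waits for exactly that long instead of a fixed `kConnectTimeout`. A toy rendezvous with a caller-supplied deadline, written in plain standard C++ rather than the absl/gRPC code, shows the same shape:

```cpp
// Toy rendezvous whose wait duration comes from the request, mirroring how
// the Connect() handler now uses request->timeout_milliseconds() instead of
// a hard-coded constant.
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

class Rendezvous {
 public:
  explicit Rendezvous(int expected) : expected_(expected) {}

  // Returns true if all nodes arrived before `timeout` elapsed.
  bool Connect(std::chrono::milliseconds timeout) {
    std::unique_lock<std::mutex> lock(mu_);
    ++arrived_;
    cv_.notify_all();
    return cv_.wait_for(lock, timeout, [&] { return arrived_ >= expected_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int expected_;
  int arrived_ = 0;
};

int main() {
  Rendezvous rendezvous(/*expected=*/2);
  std::thread other([&] {
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    rendezvous.Connect(std::chrono::milliseconds(500));
  });
  // Each caller picks its own deadline, just like ConnectRequest now does.
  bool ok = rendezvous.Connect(std::chrono::milliseconds(500));
  other.join();
  std::cout << (ok ? "all nodes connected" : "timed out") << "\n";
}
```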
#include "tensorflow/compiler/xla/pjrt/distributed/service.h" +#include "absl/time/time.h" #include "tensorflow/compiler/xla/pjrt/distributed/protocol.h" #include "tensorflow/compiler/xla/pjrt/distributed/util.h" #include "tensorflow/compiler/xla/status.h" @@ -69,11 +70,12 @@ void BuildGlobalTopology(absl::Span local_topologies, mu_.AssertHeld(); return num_nodes_present_ == nodes_.size(); }; + auto connect_timeout = absl::Milliseconds(request->timeout_milliseconds()); if (!mu_.AwaitWithTimeout(absl::Condition(&all_nodes_present), - kConnectTimeout)) { + connect_timeout)) { return ToGrpcStatus(tensorflow::errors::DeadlineExceeded( "Timed out after %s waiting for all nodes to call Connect()", - absl::FormatDuration(kConnectTimeout))); + absl::FormatDuration(connect_timeout))); } if (node_id == 0) { diff --git a/tensorflow/compiler/xla/pjrt/distributed/service.h b/tensorflow/compiler/xla/pjrt/distributed/service.h index 9ecbdb3cc7c..fe323d9f3b2 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/service.h +++ b/tensorflow/compiler/xla/pjrt/distributed/service.h @@ -50,8 +50,6 @@ class DistributedRuntimeServiceImpl final KeyValueSetResponse* response) override; private: - const absl::Duration kConnectTimeout = absl::Seconds(120); - absl::Mutex mu_; enum class State { kInitializing, kRunning }; State state_ ABSL_GUARDED_BY(mu_) = State::kInitializing; diff --git a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc index d54be61fbb8..298c41c7f58 100644 --- a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc +++ b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc @@ -32,7 +32,7 @@ TEST(GpuMultiStream, Basics) { GetNvidiaGpuClient(/*asynchronous=*/true, GpuAllocatorConfig(), /*distributed_client=*/nullptr, /*node_id=*/0)); - Device* device = client->local_devices().at(0); + PjRtDevice* device = client->local_devices().at(0); int n = 1024; Shape shape = ShapeUtil::MakeShape(S32, {n}); diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.cc b/tensorflow/compiler/xla/pjrt/interpreter_device.cc index f7138a8c181..c1149f2dbf9 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.cc +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.cc @@ -25,8 +25,8 @@ static const char kInterpreterPlatformName[] = "interpreter"; InterpreterDevice::InterpreterDevice( int id, std::unique_ptr local_device_state) - : Device(id, std::move(local_device_state), kInterpreterPlatformName, - /*device_kind=*/kInterpreterPlatformName) {} + : PjRtDevice(id, std::move(local_device_state), kInterpreterPlatformName, + /*device_kind=*/kInterpreterPlatformName) {} StatusOr> GetInterpreterClient() { TF_ASSIGN_OR_RETURN(se::Platform * platform, @@ -40,7 +40,7 @@ StatusOr> GetInterpreterClient() { TF_ASSIGN_OR_RETURN(LocalClient * client, ClientLibrary::GetOrCreateLocalClient(options)); - std::vector> devices; + std::vector> devices; se::StreamExecutor* executor = client->backend().stream_executor(0).ValueOrDie(); auto device_state = absl::make_unique( diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.h b/tensorflow/compiler/xla/pjrt/interpreter_device.h index 58b210ad762..cf732f70124 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.h +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.h @@ -23,7 +23,7 @@ limitations under the License. 
namespace xla { -class InterpreterDevice : public Device { +class InterpreterDevice : public PjRtDevice { public: InterpreterDevice(int id, std::unique_ptr local_device_state); diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc index edffaf6c877..6e387f8738f 100644 --- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/gpu/gpu_host_allocator.h" #include "tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h" +#include "tensorflow/core/util/env_var.h" #include "tensorflow/stream_executor/tf_allocator_adapter.h" namespace xla { @@ -89,12 +90,20 @@ StatusOr> CreateBFCAllocator( CHECK_GT(local_devices.size(), 0); const se::Platform* platform = local_devices.front()->executor()->platform(); std::vector allocators; + bool enable_unified_memory; + Status status = tensorflow::ReadBoolFromEnvVar("TF_FORCE_UNIFIED_MEMORY", + false, &enable_unified_memory); + if (!status.ok()) { + LOG(ERROR) << "Unable to read TF_FORCE_UNIFIED_MEMORY: " + << status.error_message(); + } + for (auto& local_device : local_devices) { se::StreamExecutor* executor = local_device->executor(); int device_ordinal = executor->device_ordinal(); auto sub_allocator = absl::make_unique( executor, tensorflow::PlatformGpuId(device_ordinal), - /*use_unified_memory=*/false, + /*use_unified_memory=*/enable_unified_memory, /*alloc_visitors=*/std::vector(), /*free_visitors=*/std::vector()); @@ -104,7 +113,10 @@ StatusOr> CreateBFCAllocator( return Unavailable("Failed to query available memory from device %i", device_ordinal); } - size_t allocator_memory = free_memory * memory_fraction; + // To allow full GPU memory to be visible to the BFC allocator if using + // unified memory. + size_t allocator_memory = + enable_unified_memory ? 
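Reviewer note: the nvidia_gpu_device.cc hunk above enables CUDA unified memory when the `TF_FORCE_UNIFIED_MEMORY` environment variable is set, and in that case sizes the BFC allocator from total device memory instead of `free_memory * memory_fraction`. The sketch below reproduces just that decision with plain `getenv` parsing; the helper names are stand-ins, not the TensorFlow APIs.

```cpp
// Sketch: read a boolean from the environment with a default, then pick the
// allocator budget the way the unified-memory path does.
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <iostream>

// Returns `default_value` when the variable is unset; accepts "1"/"true"
// (a simplified stand-in for ReadBoolFromEnvVar).
bool ReadBoolEnv(const char* name, bool default_value) {
  const char* value = std::getenv(name);
  if (value == nullptr) return default_value;
  return std::strcmp(value, "1") == 0 || std::strcmp(value, "true") == 0;
}

size_t AllocatorBudget(size_t free_memory, size_t total_memory,
                       double memory_fraction, bool unified_memory) {
  // With unified memory, expose the full device memory to the allocator;
  // otherwise stay within a fraction of what is currently free.
  return unified_memory
             ? total_memory
             : static_cast<size_t>(free_memory * memory_fraction);
}

int main() {
  const bool unified = ReadBoolEnv("TF_FORCE_UNIFIED_MEMORY", false);
  const size_t free_memory = 6ull << 30;    // e.g. 6 GiB currently free
  const size_t total_memory = 16ull << 30;  // e.g. 16 GiB on the device
  std::cout << "unified memory: " << std::boolalpha << unified << "\n";
  std::cout << "budget bytes: "
            << AllocatorBudget(free_memory, total_memory, 0.9, unified)
            << "\n";
}
```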
total_memory : free_memory * memory_fraction; if (preallocate) { LOG(INFO) << "XLA backend allocating " << allocator_memory << " bytes on device " << device_ordinal @@ -207,9 +219,9 @@ StatusOr NcclIdStore::GetNcclUniqueId(const NcclCliqueKey& key) { return cache_.emplace(key_string, result.ValueOrDie()).first->second; } -std::vector> BuildLocalDevices( +std::vector> BuildLocalDevices( std::vector> local_device_states) { - std::vector> devices; + std::vector> devices; for (auto& local_device : local_device_states) { int device_ordinal = local_device->device_ordinal(); const se::DeviceDescription& description = @@ -225,7 +237,7 @@ std::vector> BuildLocalDevices( Status BuildDistributedDevices( std::vector> local_device_states, std::shared_ptr distributed_client, int node_id, - std::vector>* devices, + std::vector>* devices, GpuExecutableRunOptions* gpu_executable_run_options) { LocalTopologyProto local_topology; local_topology.set_node_id(node_id); @@ -286,8 +298,8 @@ Status BuildDistributedDevices( GpuDevice::GpuDevice(int id, std::unique_ptr local_device_state, std::string device_kind, int node_id) - : Device(id, std::move(local_device_state), kGpuPlatformName, - std::move(device_kind), node_id) {} + : PjRtDevice(id, std::move(local_device_state), kGpuPlatformName, + std::move(device_kind), node_id) {} StatusOr> GetNvidiaGpuClient( bool asynchronous, const GpuAllocatorConfig& allocator_config, @@ -302,7 +314,7 @@ StatusOr> GetNvidiaGpuClient( auto host_memory_allocator = GetGpuHostAllocator(local_device_states.front()->executor()); - std::vector> devices; + std::vector> devices; auto gpu_run_options = absl::make_unique(); if (distributed_client) { TF_RETURN_IF_ERROR(BuildDistributedDevices( diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h index bf59ddef3a9..4f22a169bd8 100644 --- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h @@ -25,7 +25,7 @@ limitations under the License. namespace xla { -class GpuDevice : public Device { +class GpuDevice : public PjRtDevice { public: GpuDevice(int id, std::unique_ptr local_device_state, std::string device_kind, int node_id); diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index c5dce4a37f7..099c7729679 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -112,19 +112,19 @@ limitations under the License. 
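[Illustrative sketch, not part of the diff] The allocator change above sizes the BFC pool differently depending on TF_FORCE_UNIFIED_MEMORY: with unified memory the full device size is exposed (the pool may oversubscribe), otherwise only a fraction of the currently free memory is claimed. A standalone restatement of that rule; ChooseAllocatorBytes and its parameters are hypothetical names:

#include <cstddef>
#include <cstdint>

size_t ChooseAllocatorBytes(bool enable_unified_memory, int64_t free_memory,
                            int64_t total_memory, double memory_fraction) {
  // Unified memory may grow past physically free memory, so expose the whole
  // device; otherwise claim only a fraction of what is currently free.
  return enable_unified_memory
             ? static_cast<size_t>(total_memory)
             : static_cast<size_t>(free_memory * memory_fraction);
}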
namespace xla { -StatusOr Device::GetLocalDeviceState() const { +StatusOr PjRtDevice::GetLocalDeviceState() const { if (local_device_state_) { return local_device_state_.get(); } return InvalidArgument("Device %s is not a local device.", DebugString()); } -std::string Device::DebugString() const { +std::string PjRtDevice::DebugString() const { return absl::StrCat(platform_name(), ":", id()); } StatusOr DevicesToDeviceAssignment( - absl::Span> devices) { + absl::Span> devices) { if (devices.empty()) { return InvalidArgument( "Device assignment passed to Compile() must be non-empty."); @@ -175,7 +175,7 @@ class CpuAllocator : public tensorflow::Allocator { PjRtClient::PjRtClient( std::string platform_name, LocalClient* client, - std::vector> devices, int host_id, + std::vector> devices, int host_id, std::unique_ptr allocator, std::unique_ptr host_memory_allocator, bool should_stage_host_to_device_transfers, @@ -201,7 +201,7 @@ PjRtClient::PjRtClient( host_memory_allocator_ = std::make_unique(); } - for (const std::unique_ptr& device : devices_) { + for (const std::unique_ptr& device : devices_) { CHECK(id_to_device_.insert({device->id(), device.get()}).second) << "Duplicate device id: " << device->id(); @@ -376,8 +376,9 @@ void RecordUsage(PjRtBuffer::ScopedHold device_buffer, // It is safe to delete the returned PjRtBuffer without further // synchronization if an error occurs before the buffer is used. StatusOr> AllocateDestinationBuffer( - const Shape& on_host_shape, Device* device, LocalDeviceState* local_device, - se::Stream* copy_stream, bool is_uninitialized_create, PjRtClient* client) { + const Shape& on_host_shape, PjRtDevice* device, + LocalDeviceState* local_device, se::Stream* copy_stream, + bool is_uninitialized_create, PjRtClient* client) { if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) { return InvalidArgument("Can't make a buffer from an empty tuple"); } @@ -574,7 +575,7 @@ StatusOr> PjRtBuffer::FromHostBuffer( const void* data, const Shape& shape, HostBufferSemantics host_buffer_semantics, std::shared_ptr buffer_reference, PjRtClient* client, - Device* device) { + PjRtDevice* device) { tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer"); VLOG(2) << "PjRtBuffer::FromHostBuffer: shape: " << shape.ToString() << " device: " << device->DebugString(); @@ -736,7 +737,7 @@ StatusOr> PjRtBuffer::FromHostBuffer( /* static */ StatusOr> PjRtBuffer::CreateUninitialized( - const Shape& shape, PjRtClient* client, Device* device) { + const Shape& shape, PjRtClient* client, PjRtDevice* device) { tensorflow::profiler::TraceMe traceme("PjRtBuffer::CreateUninitialized"); VLOG(2) << "PjRtBuffer::CreateUninitialized: shape: " << shape.ToString() << " device: " << device->DebugString(); @@ -755,7 +756,7 @@ StatusOr> PjRtBuffer::CreateUninitialized( /* static */ StatusOr> PjRtBuffer::FromHostLiteral( - const LiteralSlice& literal, PjRtClient* client, Device* device) { + const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device) { tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostLiteral"); VLOG(2) << "PjRtBuffer::FromHostLiteral: shape: " << literal.shape().ToString() << " device: " << device->DebugString(); @@ -815,7 +816,7 @@ StatusOr> PjRtBuffer::FromHostLiteral( } /*static*/ void PjRtBuffer::MakeCrossHostReceiveBuffers( - absl::Span shapes, PjRtClient* client, Device* device, + absl::Span shapes, PjRtClient* client, PjRtDevice* device, PjRtCrossHostRecvNotifier&& notifier) { if (shapes.empty()) { notifier(InvalidArgument( @@ -849,7 
+850,7 @@ StatusOr> PjRtBuffer::FromHostLiteral( PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape, std::shared_ptr device_buffer, - PjRtClient* client, Device* device) + PjRtClient* client, PjRtDevice* device) : client_(client), on_host_shape_(std::move(on_host_shape)), on_device_shape_(std::move(on_device_shape)), @@ -1189,7 +1190,7 @@ PjRtBuffer::ScopedHold PjRtBuffer::GetBufferWithHold(ScopedHold::Type type) { StatusOr, std::shared_ptr>> PjRtBuffer::CopyToDeviceHelper( - Device* dst_device, LocalDeviceState* dst_local_device, + PjRtDevice* dst_device, LocalDeviceState* dst_local_device, LocalDeviceState* transfer_local_device, se::Stream* transfer_stream, std::shared_ptr src_device_buffer) { TF_ASSIGN_OR_RETURN( @@ -1249,7 +1250,7 @@ PjRtBuffer::CopyToDeviceHelper( } StatusOr> PjRtBuffer::CopyToDevice( - Device* dst_device) { + PjRtDevice* dst_device) { tensorflow::profiler::TraceMe traceme("PjRtBuffer::CopyToDevice"); if (dst_device == device_) { return InvalidArgument( @@ -1342,8 +1343,6 @@ namespace { // Helper struct for the tuple that is transiently constructed to hold the // arguments of an execution. struct TupleHandle { - // The tuple's shape on the host. - Shape on_host_shape; // The ExecutionInput describing the tuple. ExecutionInput execution_input; // A definition event that has been recorded on the host_to_device stream @@ -1414,8 +1413,7 @@ StatusOr MakeTupleHelper( auto transfer_event = std::make_shared(); transfer_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream); - return TupleHandle({std::move(on_host_shape), std::move(execution_input), - std::move(transfer_event)}); + return TupleHandle({std::move(execution_input), std::move(transfer_event)}); } // Converts a ScopedShapedBuffer returned from an execution into a @@ -1423,20 +1421,20 @@ StatusOr MakeTupleHelper( std::unique_ptr OutputBufferHelper( ScopedShapedBuffer* result_buffer, std::shared_ptr definition_event, PjRtClient* client, - Device* device, LocalDeviceState* local_device) { + PjRtDevice* device, LocalDeviceState* local_device) { std::shared_ptr out_buffer = TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer, {definition_event}); - auto py_buffer = absl::make_unique( + auto pjrt_buffer = absl::make_unique( result_buffer->on_host_shape(), result_buffer->on_device_shape(), std::move(out_buffer), client, device); - RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device, + RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device, definition_event, local_device->compute_stream(), /*prefer_to_retain_reference=*/false); - return py_buffer; + return pjrt_buffer; } -static Device* LookupDevice(const PjRtClient& client, int device_id) { +static PjRtDevice* LookupDevice(const PjRtClient& client, int device_id) { auto it = client.id_to_device().find(device_id); CHECK(it != client.id_to_device().end()) << "Unknown device id: " << device_id; @@ -1450,7 +1448,7 @@ PjRtExecutable::PjRtExecutable( bool parameter_is_tupled_arguments, std::shared_ptr device_assignment, std::vector> local_logical_device_ids, - std::vector local_devices, PjRtClient* client) + std::vector local_devices, PjRtClient* client) : client_(client), device_assignment_(std::move(device_assignment)), parameter_is_tupled_arguments_(parameter_is_tupled_arguments), @@ -1508,15 +1506,64 @@ const std::string& PjRtExecutable::name() const { } } +bool PjRtExecutable::MustDonateParameter(int executable_idx, + int parameter) const { + return 
parameters_that_must_be_donated_[executable_idx].contains(parameter); +} + +StatusOr> +PjRtExecutable::MakeExecutionInputsAndWaitForEvents( + int device_ordinal, const ExecuteOptions& options, + absl::Span argument_handles, + absl::Span device_buffers, + absl::flat_hash_set& events) const { + std::vector execution_inputs; + LocalDeviceState* device_state = &client_->device_state(device_ordinal); + // Lift tuple_handle outside the conditional so that the event it returns is + // not destroyed until after the loop below that waits on events. + absl::optional tuple_handle; + if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) { + TF_ASSIGN_OR_RETURN(tuple_handle, + MakeTupleHelper(client_, device_state, argument_handles, + device_buffers, device_ordinal)); + events.insert(tuple_handle->event.get()); + execution_inputs.emplace_back(std::move(tuple_handle->execution_input)); + } else { + execution_inputs.reserve(argument_handles.size()); + for (int i = 0; i < argument_handles.size(); ++i) { + PjRtBuffer* handle = argument_handles[i]; + + // Make an ExecutionInput from the device buffer. + execution_inputs.emplace_back(handle->on_device_shape(), + handle->on_host_shape()); + ExecutionInput& execution_input = execution_inputs.back(); + ShapeTree::iterator input_iterator = + execution_input.MutableBuffers()->begin(); + ShapeTree::iterator iterator_end = + execution_input.MutableBuffers()->end(); + device_buffers[i].AddToInput(&input_iterator, iterator_end, + &execution_input, client_->allocator()); + CHECK(input_iterator == iterator_end); + } + } + + for (BufferSequencingEvent* event : events) { + event->WaitForEventOnStream(device_state->compute_stream()); + } + + return execution_inputs; +} + // Enqueues a computation onto the compute stream. Each buffer returned in // device_buffers has a usage hold added that must be dropped on error or // converted on success. StatusOr PjRtExecutable::EnqueueExecution( absl::Span argument_handles, int replica, int partition, int executable_idx, const RunId& run_id, const ExecuteOptions& options, - Device* device, std::vector* device_buffers, + PjRtDevice* device, std::vector* device_buffers, std::shared_ptr device_assignment) const { int device_ordinal = device->local_device_state()->device_ordinal(); + LocalDeviceState* device_state = &client_->device_state(device_ordinal); tensorflow::profiler::TraceMeConsumer activity( "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt, run_id.ToInt()); @@ -1524,10 +1571,7 @@ StatusOr PjRtExecutable::EnqueueExecution( << " mapped to device ordinal for execution: " << device_ordinal; absl::flat_hash_set events; - std::vector execution_inputs; device_buffers->reserve(argument_handles.size()); - const absl::flat_hash_set& parameters_that_must_be_donated = - parameters_that_must_be_donated_[executable_idx]; for (int i = 0; i < argument_handles.size(); ++i) { PjRtBuffer* handle = argument_handles[i]; if (handle->device() != device) { @@ -1536,8 +1580,7 @@ StatusOr PjRtExecutable::EnqueueExecution( "device %s, but replica is assigned to device %s.", i, replica, handle->device()->DebugString(), device->DebugString()); } - bool must_donate = parameters_that_must_be_donated.find(i) != - parameters_that_must_be_donated.end(); + bool must_donate = MustDonateParameter(executable_idx, i); device_buffers->emplace_back(handle->GetBufferWithHold( must_donate ? 
PjRtBuffer::ScopedHold::kDonation : PjRtBuffer::ScopedHold::kUsage)); @@ -1571,37 +1614,10 @@ StatusOr PjRtExecutable::EnqueueExecution( } } - LocalDeviceState* device_state = &client_->device_state(device_ordinal); - absl::optional tuple_handle; - if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) { - TF_ASSIGN_OR_RETURN(tuple_handle, - MakeTupleHelper(client_, device_state, argument_handles, - *device_buffers, device_ordinal)); - events.insert(tuple_handle->event.get()); - execution_inputs.emplace_back(std::move(tuple_handle->execution_input)); - } else { - execution_inputs.reserve(argument_handles.size()); - for (int i = 0; i < argument_handles.size(); ++i) { - PjRtBuffer* handle = argument_handles[i]; - - const PjRtBuffer::ScopedHold& device_buffer = (*device_buffers)[i]; - // Make an ExecutionInput from the device buffer. - execution_inputs.emplace_back(handle->on_device_shape(), - handle->on_host_shape()); - ExecutionInput& execution_input = execution_inputs.back(); - ShapeTree::iterator input_iterator = - execution_input.MutableBuffers()->begin(); - ShapeTree::iterator iterator_end = - execution_input.MutableBuffers()->end(); - device_buffer.AddToInput(&input_iterator, iterator_end, &execution_input, - client_->allocator()); - CHECK(input_iterator == iterator_end); - } - } - - for (BufferSequencingEvent* event : events) { - event->WaitForEventOnStream(device_state->compute_stream()); - } + TF_ASSIGN_OR_RETURN( + std::vector execution_inputs, + MakeExecutionInputsAndWaitForEvents( + device_ordinal, options, argument_handles, *device_buffers, events)); ExecutableRunOptions run_options; run_options.set_stream(device_state->compute_stream()); @@ -1676,11 +1692,45 @@ StatusOr PjRtExecutable::EnqueueExecution( return result_buffer_or_status.ConsumeValueOrDie().ConsumeResult(); } +std::vector> PjRtExecutable::MakeOutputBuffers( + int device_ordinal, const ExecuteOptions& options, + ScopedShapedBuffer result_buffer, + std::shared_ptr definition_event, + PjRtDevice* device) const { + std::vector> outputs; + LocalDeviceState* device_state = &client_->device_state(device_ordinal); + if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) { + int tuple_count = result_buffer.on_host_shape().tuple_shapes_size(); + outputs.reserve(tuple_count); + // Take ownership of each of the output values, leaving only the root table + // in result_buffer. + for (int i = 0; i < tuple_count; ++i) { + ScopedShapedBuffer tuple_buffer = result_buffer.TakeSubTree({i}); + outputs.push_back(OutputBufferHelper(&tuple_buffer, definition_event, + client_, device, device_state)); + } + if (device_state->allocation_model() == LocalDeviceState::kSynchronous) { + // Don't release the root buffer until after execution completes. 
+ ShapedBuffer root_buffer_holder = result_buffer.release(); + se::DeviceMemoryBase root_buffer = root_buffer_holder.root_buffer(); + device_state->ThenExecuteOnCallbackThread( + device_state->compute_stream(), + [root_buffer, allocator{client_->allocator()}, device_ordinal]() { + TF_CHECK_OK(allocator->Deallocate(device_ordinal, root_buffer)); + }); + } + } else { + outputs.push_back(OutputBufferHelper(&result_buffer, definition_event, + client_, device, device_state)); + } + return outputs; +} + StatusOr>> PjRtExecutable::ExecuteHelper(absl::Span argument_handles, int replica, int partition, const RunId& run_id, const ExecuteOptions& options, - Device* device) const { + PjRtDevice* device) const { std::shared_ptr device_assignment; if (device == nullptr) { CHECK(device_assignment_ != nullptr); @@ -1737,31 +1787,9 @@ PjRtExecutable::ExecuteHelper(absl::Span argument_handles, } auto definition_event = std::make_shared(); definition_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream); - std::vector> outputs; - if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) { - int tuple_count = result_buffer.on_host_shape().tuple_shapes_size(); - outputs.reserve(tuple_count); - // Take ownership of each of the output values, leaving only the root table - // in result_buffer. - for (int i = 0; i < tuple_count; ++i) { - ScopedShapedBuffer tuple_buffer = result_buffer.TakeSubTree({i}); - outputs.push_back(OutputBufferHelper(&tuple_buffer, definition_event, - client_, device, device_state)); - } - if (device_state->allocation_model() == LocalDeviceState::kSynchronous) { - // Don't release the root buffer until after execution completes. - ShapedBuffer root_buffer_holder = result_buffer.release(); - se::DeviceMemoryBase root_buffer = root_buffer_holder.root_buffer(); - device_state->ThenExecuteOnCallbackThread( - device_state->compute_stream(), - [root_buffer, allocator{client_->allocator()}, device_ordinal]() { - TF_CHECK_OK(allocator->Deallocate(device_ordinal, root_buffer)); - }); - } - } else { - outputs.push_back(OutputBufferHelper(&result_buffer, definition_event, - client_, device, device_state)); - } + std::vector> outputs = + MakeOutputBuffers(device_ordinal, options, std::move(result_buffer), + definition_event, device); for (PjRtBuffer::ScopedHold& b : device_buffers) { // prefer_to_retain_reference=false because when using the @@ -1801,7 +1829,7 @@ StatusOr>> PjRtExecutable::Execute( StatusOr>> PjRtExecutable::ExecuteOnLocalDevice( - absl::Span argument_handles, Device* device, + absl::Span argument_handles, PjRtDevice* device, const ExecuteOptions& options) const { if (device_assignment_ == nullptr) { VLOG(1) << "Executing portable single-core program on " @@ -1867,7 +1895,7 @@ PjRtExecutable::ExecuteOnLocalDevices( for (int i = 0; i < num_local_devices; ++i) { const int replica = local_logical_device_ids_[i].first; const int partition = local_logical_device_ids_[i].second; - Device* device = local_devices_[i]; + PjRtDevice* device = local_devices_[i]; const LocalDeviceState& device_state = *device->local_device_state(); device_state.execute_thread()->Schedule([&, replica, partition, i] { results[i] = ExecuteHelper(argument_handles[i], replica, partition, @@ -2114,12 +2142,12 @@ StatusOr, Shape>> GetShardedProgramShapes( build_options.set_result_layout(result_layout); std::vector> local_logical_device_ids; - std::vector local_devices; + std::vector local_devices; if (device_assignment != nullptr) { for (int replica = 0; replica < num_replicas; ++replica) { for 
(int partition = 0; partition < num_partitions; ++partition) { int device_id = (*device_assignment)(replica, partition); - Device* device = LookupDevice(*client, device_id); + PjRtDevice* device = LookupDevice(*client, device_id); if (device->host_id() != client->host_id()) { VLOG(3) << "Non-local device: " << device_id; continue; diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index bb9093a8bf7..39711534f79 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -52,17 +52,18 @@ namespace xla { class PjRtClient; -class Device { +class PjRtDevice { public: - explicit Device(int id, std::unique_ptr local_device_state, - std::string platform_name, std::string device_kind, - int host_id = 0) + explicit PjRtDevice(int id, + std::unique_ptr local_device_state, + std::string platform_name, std::string device_kind, + int host_id = 0) : id_(id), local_device_state_(std::move(local_device_state)), host_id_(host_id), platform_name_(std::move(platform_name)), device_kind_(std::move(device_kind)) {} - virtual ~Device() {} + virtual ~PjRtDevice() {} // The ID of this device. IDs are unique among devices of this type // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all @@ -130,7 +131,7 @@ class PjRtClient { // `allocator` may null, in which case the platform default allocator is used. explicit PjRtClient( std::string platform_name, LocalClient* client, - std::vector> devices, int host_id, + std::vector> devices, int host_id, std::unique_ptr allocator, std::unique_ptr host_memory_allocator, bool should_stage_host_to_device_transfers, @@ -142,11 +143,15 @@ class PjRtClient { int device_count() const { return devices_.size(); } int local_device_count() const { return local_devices_.size(); } - const std::vector>& devices() const { + const std::vector>& devices() const { return devices_; } - const std::vector& local_devices() const { return local_devices_; } - const std::map& id_to_device() const { return id_to_device_; } + const std::vector& local_devices() const { + return local_devices_; + } + const std::map& id_to_device() const { + return id_to_device_; + } int host_id() const { return host_id_; } const std::string& platform_name() const { return platform_name_; } @@ -210,11 +215,11 @@ class PjRtClient { std::unique_ptr host_memory_allocator_; // Includes all devices, including non-local devices on multi-host platforms. - std::vector> devices_; + std::vector> devices_; // Maps Device::id() to the corresponding Device. Includes all devices. - std::map id_to_device_; + std::map id_to_device_; // Local devices indexed by local device ordinal. - std::vector local_devices_; + std::vector local_devices_; int host_id_; se::DeviceMemoryAllocator* allocator_; @@ -233,7 +238,7 @@ class PjRtClient { // Converts a 2D set of Device objects indexed by [replica][partition] into an // xla::DeviceAssignment. StatusOr DevicesToDeviceAssignment( - absl::Span> devices); + absl::Span> devices); // Holds a reference from Python to a tuple of device buffers. A PjRtBuffer // can be either valid or invalid. An invalid buffer is one that has never been @@ -417,7 +422,7 @@ class PjRtBuffer { // Returns a buffer with uninitialized contents. 
static StatusOr> CreateUninitialized( - const Shape& shape, PjRtClient* client, Device* device); + const Shape& shape, PjRtClient* client, PjRtDevice* device); // Describes the semantics the caller to FromHostBuffer expects from the // runtime, in a total order from most restrictive to least restrictive. @@ -449,13 +454,13 @@ class PjRtBuffer { const void* data, const Shape& shape, HostBufferSemantics host_buffer_semantics, std::shared_ptr buffer_reference, PjRtClient* client, - Device* device); + PjRtDevice* device); // Note that literal must remain in scope until the transfer has completed, so // the caller should, for example, wait for BlockHostUntilReady() completes on // the return value before letting literal go out of scope. static StatusOr> FromHostLiteral( - const LiteralSlice& literal, PjRtClient* client, Device* device); + const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device); // Asynchronously makes a vector of PjRtBuffers that can be used to receive // cross host transfers using `client` on `device'. `shapes` must be the exact @@ -467,12 +472,13 @@ class PjRtBuffer { // sending host and used in a call to CopyToRemoteDevice. None of the recv // buffers will become ready until *all* of the sends have completed. static void MakeCrossHostReceiveBuffers(absl::Span shapes, - PjRtClient* client, Device* device, + PjRtClient* client, + PjRtDevice* device, PjRtCrossHostRecvNotifier&& notifier); PjRtBuffer(Shape on_host_shape, Shape on_device_shape, std::shared_ptr device_buffer, - PjRtClient* client, Device* device); + PjRtClient* client, PjRtDevice* device); ~PjRtBuffer(); PjRtBuffer(const PjRtBuffer&) = delete; @@ -482,7 +488,7 @@ class PjRtBuffer { const Shape& on_host_shape() const { return on_host_shape_; } const Shape& on_device_shape() const { return on_device_shape_; } - Device* device() const { return device_; } + PjRtDevice* device() const { return device_; } const std::string& platform_name() const { return client_->platform_name(); } PjRtClient* client() const { return client_; } bool IsEmptyTuple() const { @@ -556,7 +562,7 @@ class PjRtBuffer { // Copies the buffer to device `dst_device`. Returns an error if the buffer is // already on dst_device. - StatusOr> CopyToDevice(Device* dst_device); + StatusOr> CopyToDevice(PjRtDevice* dst_device); // Copies the buffer to the remote device encoded in serialized_descriptor. 
// This call must be preceded by a call to MakeCrossHostReceiveBuffers on the @@ -629,7 +635,7 @@ class PjRtBuffer { StatusOr, std::shared_ptr>> - CopyToDeviceHelper(Device* dst_device, LocalDeviceState* dst_local_device, + CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device, LocalDeviceState* transfer_local_device, se::Stream* transfer_stream, std::shared_ptr src_device_buffer); @@ -637,7 +643,7 @@ class PjRtBuffer { PjRtClient* const client_; const Shape on_host_shape_; const Shape on_device_shape_; - Device* const device_; + PjRtDevice* const device_; mutable absl::Mutex mu_; std::shared_ptr device_buffer_ TF_GUARDED_BY(mu_); @@ -668,6 +674,11 @@ struct CompileOptions { bool compile_portable_executable = false; }; +class ExecuteContext { + public: + virtual ~ExecuteContext() = default; +}; + struct ExecuteOptions { // If true, the client must pass a single PjRtBuffer which contains all of // the arguments as a single XLA tuple, otherwise each argument must be @@ -682,6 +693,9 @@ struct ExecuteOptions { // multi-host programs are launched in different orders on different hosts, // the launch IDs may be used by the runtime to detect the mismatch. int32 launch_id = 0; + // If non-null, an opaque context passed to an execution that may be used to + // supply additional arguments to a derived class of PjRtExecutable. + const ExecuteContext* context = nullptr; }; // Represents a compiled computation that can be executed given handles to @@ -699,7 +713,7 @@ class PjRtExecutable { bool parameter_is_tupled_arguments, std::shared_ptr device_assignment, std::vector> local_logical_device_ids, - std::vector local_devices, PjRtClient* client); + std::vector local_devices, PjRtClient* client); virtual ~PjRtExecutable() = default; @@ -733,14 +747,16 @@ class PjRtExecutable { return local_logical_device_ids_; } - const std::vector& local_devices() const { return local_devices_; } + const std::vector& local_devices() const { + return local_devices_; + } StatusOr>> Execute( absl::Span argument_handles, const ExecuteOptions& options) const; StatusOr>> ExecuteOnLocalDevice( - absl::Span argument_handles, Device* device, + absl::Span argument_handles, PjRtDevice* device, const ExecuteOptions& options) const; // Execute on local devices. Takes a sequence of argument lists (one argument @@ -756,22 +772,42 @@ class PjRtExecutable { const string& name() const; + protected: + bool parameter_is_tupled_arguments() const { + return parameter_is_tupled_arguments_; + } + private: // Initializes information about which arguments to which executables must be // donated due to aliases that were specified by the computation. 
Status SetUpDonation(PjRtClient* client, bool tuple_inputs); + virtual bool MustDonateParameter(int executable_idx, int parameter) const; + + virtual StatusOr> + MakeExecutionInputsAndWaitForEvents( + int device_ordinal, const ExecuteOptions& options, + absl::Span argument_handles, + absl::Span device_buffers, + absl::flat_hash_set& events) const; + StatusOr EnqueueExecution( absl::Span argument_handles, int replica, int partition, int executable_idx, const RunId& run_id, - const ExecuteOptions& options, Device* device, + const ExecuteOptions& options, PjRtDevice* device, std::vector* device_buffers, std::shared_ptr device_assignment) const; + virtual std::vector> MakeOutputBuffers( + int device_ordinal, const ExecuteOptions& options, + ScopedShapedBuffer result_buffer, + std::shared_ptr definition_event, + PjRtDevice* device) const; + StatusOr>> ExecuteHelper( absl::Span argument_handles, int replica, int partition, const RunId& run_id, const ExecuteOptions& options, - Device* device = nullptr) const; + PjRtDevice* device = nullptr) const; // Create shared pointers so we can free them after the execution: with // asynchronous execution, the process being executed can outlive the @@ -800,7 +836,7 @@ class PjRtExecutable { // assigned. // shared_ptrs instead of unique_ptrs to play well with the Python bindings // (see xla.cc). - std::vector local_devices_; + std::vector local_devices_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 2143d1dfbe7..c932469c56a 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -112,6 +112,21 @@ xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) { } } +xla::PrimitiveType SignedIntegralTypeForBitWidth(int64 src_bitwidth) { + switch (src_bitwidth) { + case 8: + return xla::S8; + case 16: + return xla::S16; + case 32: + return xla::S32; + case 64: + return xla::S64; + default: + return xla::PRIMITIVE_TYPE_INVALID; + } +} + PrimitiveType ComplexComponentType(PrimitiveType complex_type) { switch (complex_type) { case C64: diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 034c14e8930..1228b4f9a32 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -153,6 +153,8 @@ int BitWidth(PrimitiveType type); PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth); +PrimitiveType SignedIntegralTypeForBitWidth(int64 src_bitwidth); + // Returns the real, imag component type underlying the given complex type. // LOG(FATAL)'s if complex_type is not complex. 
PrimitiveType ComplexComponentType(PrimitiveType complex_type); diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index aa55a39218d..6ad1d789d48 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -6,7 +6,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "pybind_extension") package( - default_visibility = ["//tensorflow:internal"], + default_visibility = [ + "//learning/pathways/data_parallel/jax:__subpackages__", + "//tensorflow:internal", + ], licenses = ["notice"], # Apache 2.0 ) @@ -155,7 +158,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/core/lib/bfloat16", + "//tensorflow/core/platform:bfloat16", "//tensorflow/core/platform:logging", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", # buildcleaner: keep @@ -242,6 +245,34 @@ cc_library( ], ) +cc_library( + name = "jax_jit", + srcs = ["jax_jit.cc"], + hdrs = ["jax_jit.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = ["//visibility:private"], + deps = [ + ":py_client", + ":pytree", + ":types", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", + "@pybind11", + ], +) + cc_library( name = "ops", srcs = ["ops.cc"], @@ -257,6 +288,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:comparators", + "//tensorflow/compiler/xla/client/lib:lu_decomposition", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:qr", "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", @@ -327,6 +359,27 @@ cc_library( ], ) +# TODO(phawkins): this library is really part of JAX. Find a better home for it. +cc_library( + name = "pytree", + srcs = ["pytree.cc"], + hdrs = ["pytree.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@pybind11", + ], +) + config_setting( name = "enable_gpu", values = {"define": "xla_python_enable_gpu=true"}, @@ -346,8 +399,10 @@ pybind_extension( deps = [ ":bfloat16", ":dlpack", + ":jax_jit", ":ops", ":py_client", + ":pytree", ":python_ref_manager", ":outfeed_receiver_py", ":traceback", diff --git a/tensorflow/compiler/xla/python/bfloat16.cc b/tensorflow/compiler/xla/python/bfloat16.cc index 1f21b3fb242..b70244cc3ef 100644 --- a/tensorflow/compiler/xla/python/bfloat16.cc +++ b/tensorflow/compiler/xla/python/bfloat16.cc @@ -27,7 +27,7 @@ limitations under the License. 
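[Illustrative usage, not part of the diff] The SignedIntegralTypeForBitWidth helper added to primitive_util above mirrors the existing unsigned variant. A minimal caller, assuming the helper lives in the xla::primitive_util namespace like its unsigned counterpart; SignedTypeForWidth is a hypothetical wrapper, and unsupported widths yield PRIMITIVE_TYPE_INVALID:

#include <cstdint>

#include "tensorflow/compiler/xla/primitive_util.h"

// Maps a bit width to the matching signed integral type,
// e.g. 8 -> S8, 16 -> S16, 32 -> S32, 64 -> S64.
xla::PrimitiveType SignedTypeForWidth(int64_t bit_width) {
  return xla::primitive_util::SignedIntegralTypeForBitWidth(bit_width);
}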
#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index 4fc17172ea7..67afa25d23e 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -193,7 +193,7 @@ StatusOr> StridesToLayout(absl::Span dims, return minor_to_major; } -StatusOr DLDeviceTypeForDevice(const Device& device) { +StatusOr DLDeviceTypeForDevice(const PjRtDevice& device) { const se::Platform* platform = device.local_device_state()->executor()->platform(); if (platform->id() == se::host::kHostPlatformId) { @@ -205,15 +205,15 @@ StatusOr DLDeviceTypeForDevice(const Device& device) { device.DebugString()); } -StatusOr DLContextForDevice(const Device& device) { +StatusOr DLContextForDevice(const PjRtDevice& device) { DLContext context; TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device)); context.device_id = device.local_device_state()->device_ordinal(); return context; } -StatusOr DeviceForDLContext(const PjRtClient& client, - const DLContext& context) { +StatusOr DeviceForDLContext(const PjRtClient& client, + const DLContext& context) { se::Platform::Id platform_id; switch (context.device_type) { case kDLCPU: @@ -226,7 +226,7 @@ StatusOr DeviceForDLContext(const PjRtClient& client, return InvalidArgument("Unknown/unsupported DLPack device type %d", context.device_type); } - auto it = absl::c_find_if(client.local_devices(), [&](Device* device) { + auto it = absl::c_find_if(client.local_devices(), [&](PjRtDevice* device) { return device->local_device_state()->executor()->platform()->id() == platform_id && device->local_device_state()->device_ordinal() == context.device_id; @@ -313,7 +313,7 @@ StatusOr> DLPackManagedTensorToBuffer( dlmt->dl_tensor.ndim); } TF_ASSIGN_OR_RETURN( - Device * device, + PjRtDevice * device, DeviceForDLContext(*client->pjrt_client(), dlmt->dl_tensor.ctx)); absl::Span dimensions( reinterpret_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); @@ -321,7 +321,8 @@ StatusOr> DLPackManagedTensorToBuffer( DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype)); std::vector minor_to_major; - if (dlmt->dl_tensor.strides && !absl::c_find(dimensions, 0)) { + if (dlmt->dl_tensor.strides && + absl::c_find(dimensions, 0) == dimensions.end()) { absl::Span strides( reinterpret_cast(dlmt->dl_tensor.strides), dlmt->dl_tensor.ndim); diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc new file mode 100644 index 00000000000..2c364573e5b --- /dev/null +++ b/tensorflow/compiler/xla/python/jax_jit.cc @@ -0,0 +1,830 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This file implements the `jax.jit` dispatch and just-in-time feature. +// +// In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward, +// based on the dtypes/shapes/identity of the passed arguments) the execution to a +// just-in-time compiled XLA Executable. All of that is done in C++ for +// performance reasons. +// +// This file contains the utilities to: +// (a) inspect arguments and describe their structure, dtype/shapes, etc. +// (b) keep a mapping from function signatures to compiled XLA Executables. + +#include "tensorflow/compiler/xla/python/jax_jit.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/notification.h" +#include "absl/types/optional.h" +#include "pybind11/cast.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/python/py_buffer.h" +#include "tensorflow/compiler/xla/python/py_executable.h" +#include "tensorflow/compiler/xla/python/pytree.h" +#include "tensorflow/compiler/xla/python/types.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace xla { + +namespace py = pybind11; + +// TODO(phawkins): Add support for Tracers. +// TODO(jblespiau): Use absl Status. + +namespace { + +thread_local bool disable_jit; +void SetDisableJit(bool disable_jit_) { disable_jit = disable_jit_; } +bool GetDisableJit() { return disable_jit; } + +// Describes the abstract shape and dtype of an argument. +struct ArgSignature { + // This is the XLA dtype of the object. + xla::PrimitiveType dtype; + // JAX arguments can be of weak type, if and only if they are Python scalars + // or `DeviceArray` values such that `aval.weak_type` is true. + bool weak_type; + absl::InlinedVector shape; + bool operator==(const ArgSignature& other) const { + return std::tie(dtype, weak_type, shape) == + std::tie(other.dtype, other.weak_type, other.shape); + } + bool operator!=(const ArgSignature& other) const { return !(*this == other); } + + std::string DebugString() const { + std::string result = ""; + if (weak_type) { + absl::StrAppend(&result, "weak_"); + } + absl::StrAppend(&result, xla::PrimitiveType_Name(dtype)); + absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]"); + return result; + } +}; + +template +H AbslHashValue(H h, const ArgSignature& s) { + h = H::combine(std::move(h), s.dtype); + if (!s.shape.empty()) { + h = H::combine_contiguous(std::move(h), &s.shape.front(), s.shape.size()); + } + return h; +} + +// The signature of a Python jitted function call, partitioned into: +// - dynamic positional arguments (i.e. positional args which are not static) +// - static positional arguments (i.e. the args associated with static_argnums) +// - keyword arguments +// The CallSignature should unambiguously identify a function call; thus, +// equality is based on: +// (a) the same PyTree for all dynamic positional arguments and keyword arguments +// (b) equality of the ArgSignature of the dynamic positional and keyword arguments +// (c) equality (delegated to Python) of the static arguments. +struct CallSignature { + struct KwargEntry { + // To avoid comparing strings, we intern the kwargs strings.
+ // The compilation cache holds a reference to all the keys. + py::handle key; + PyTreeDef value_treedef; + bool operator==(const KwargEntry& other) const { + return key.ptr() == other.key.ptr() && + value_treedef == other.value_treedef; + } + bool operator!=(const KwargEntry& other) const { return !(*this == other); } + }; + + // Only contains the arguments associated to `static_argnums`, sorted in the + // order of their argnum index. + std::vector static_args; + // A PyTreeDef for each positional dynamic (i.e. not static) argument. + std::vector dynamic_positional_args_treedef; + // Keyword arguments. Sorted by the interned keyword pointers. + std::vector keyword_args; + // Shape and dtype for both the dynamic positional arguments and the keyword + // arguments (sorted by interned keyword pointers). + std::vector dynamic_args_signatures; + PjRtDevice* device; + + bool operator==(const CallSignature& other) const { + return std::tie(dynamic_positional_args_treedef, static_args, keyword_args, + dynamic_args_signatures, device) == + std::tie(other.dynamic_positional_args_treedef, other.static_args, + other.keyword_args, other.dynamic_args_signatures, + other.device); + } + bool operator!=(const CallSignature& other) const { + return !(*this == other); + } + + // To be used when we want to keep ownership of Python values referenced by + // the `CallSignature` (i.e. when we insert an entry). + void IncRef() const; + // The destructor of the cache should call this on all entries. + void DecRef() const; + + std::string DebugString() const; +}; + +void CallSignature::IncRef() const { + for (const auto& kw : keyword_args) { + kw.key.inc_ref(); + } +} + +void CallSignature::DecRef() const { + for (const auto& kw : keyword_args) { + kw.key.dec_ref(); + } +} + +template +H AbslHashValue(H h, const CallSignature::KwargEntry& kw) { + h = H::combine(std::move(h), kw.key.ptr(), kw.value_treedef); + return h; +} + +template +H AbslHashValue(H h, const CallSignature& s) { + // /!\ important: We cannot include static arguments to the hash, because + // the py::object must be hashable for absl. We can try delegating to the + // Python __hash__, but there are many non-hashable Python types such as + // np.ndarray. + // TODO(jblespiau): We should either ban non-hashable objects from jit or we + // should hash them by object identity. 
+ h = H::combine_contiguous(std::move(h), + &s.dynamic_positional_args_treedef.front(), + s.dynamic_positional_args_treedef.size()); + h = H::combine_contiguous(std::move(h), &s.keyword_args.front(), + s.keyword_args.size()); + h = H::combine_contiguous(std::move(h), &s.dynamic_args_signatures.front(), + s.dynamic_args_signatures.size()); + h = H::combine(std::move(h), s.device); + return h; +} + +std::string CallSignature::DebugString() const { + std::vector static_args_str; + static_args_str.reserve(static_args.size()); + for (auto& static_arg : static_args) { + static_args_str.emplace_back(py::cast(static_arg.str())); + } + + std::vector signature_str; + signature_str.reserve(dynamic_args_signatures.size()); + + for (auto& arg_signature : dynamic_args_signatures) { + signature_str.emplace_back(arg_signature.DebugString()); + } + std::vector tree_def_str; + tree_def_str.reserve(dynamic_positional_args_treedef.size()); + for (auto& tree_def : dynamic_positional_args_treedef) { + tree_def_str.emplace_back(tree_def.ToString()); + } + std::vector keyword_names; + keyword_names.reserve(keyword_args.size()); + for (auto& kwarg_entry : keyword_args) { + keyword_names.emplace_back(py::cast(kwarg_entry.key)); + tree_def_str.emplace_back(kwarg_entry.value_treedef.ToString()); + } + return absl::StrCat( + static_args.size(), " static_args: ", absl::StrJoin(static_args_str, ","), + "\n", // new line + keyword_args.size(), " keyword args:", absl::StrJoin(keyword_names, ","), + "\n", // new-line + dynamic_positional_args_treedef.size(), " positional args.\n", + dynamic_args_signatures.size(), + " dynamic args (positional+keyword):\n - ", + absl::StrJoin(signature_str, ", "), "\n - ", + absl::StrJoin(tree_def_str, " | ")); +} + +struct CacheEntry { + std::shared_ptr executable; + xla::PjRtDevice* device; + PyTreeDef out_pytree_def; + // These are the objects required to create a `DeviceArray` object. + // We use Python types within the vector because this is what we will be + // returning to Python. No need to convert back and forth. + // We need py::object to keep the objects alive. + std::vector out_avals; + std::vector out_lazy_exprs; + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated + // with a signature; if the entry has already been inserted, other threads + // will wait for the notification. + absl::Notification compilation_complete; + absl::optional compilation_error = absl::nullopt; +}; + +// A `CompiledFunction` is associated with a `jax.jit(f)` and takes care of the +// bookkeeping of the different signatures used and the dispatch of calls to +// the correct underlying `PyExecutable`. This class is thread-safe. +class CompiledFunction { + public: + CompiledFunction(py::function fun, py::function cache_miss_fun, + py::function python_f_jitted, bool jax_enable_x64, + bool jax_disable_jit, std::vector static_argnums); + ~CompiledFunction(); + + // This function will: + // (a) flatten the inputs using pytree + // (b) get buffer objects from the arguments + // (c) call the executable + // (d) construct `DeviceArray` objects from the outputs + // (e) reconstruct the `PyTree`. + py::object Call(py::args args, py::kwargs kwargs); + + // This allows `inspect.signature(cpp_jitted_f)` from Python.
+ py::object __signature__() { + static const auto* inspect = new py::module(py::module::import("inspect")); + return inspect->attr("signature")(fun_); + } + + private: + CacheEntry& GetCacheEntry(const py::args& args, const py::kwargs& kwargs, + const CallSignature& signature, + absl::optional cache_miss_return); + CacheEntry& SetAndReturnCacheEntry( + const py::args& args, const py::kwargs& kwargs, + const CallSignature& signature, + absl::optional cache_miss_return = absl::nullopt); + bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_; } + + const py::function fun_; // The Python function to jit. + // The Python function in charge of returning a `xla::PyExecutable` from + // the arguments passed to `jitted_f`. + const py::function cache_miss_fun_; + // A function to call as fallback. This is the result of calling the Python + // `jax.jit`. + // TODO(jblespiau): Delete this when the C++ codepath supports all features. + const py::function python_f_jitted_; + + // The value of the Python flag when the object was created. + const bool jax_enable_x64_; + const bool jax_disable_jit_; + + // We need to know the static arguments to remove them from the arguments + // passed to the underlying PyExecutable. In sorted order. + std::vector static_argnums_; + // We need a `unique_ptr` here to ensure value pointer stability. + absl::flat_hash_map> executables_; + + // As top-level functions are decorated with `jax.jit`, when + // `CompiledFunction` is being instantiated from Python, the clients are not + // yet available (done after GoogleInit). They will be during the first call + // to `Call`. + std::shared_ptr pyclient_ = nullptr; + xla::PjRtDevice* default_device_ = nullptr; + + // IMPORTANT: The GIL is not always held, because we call back to Python and + // Python will release the GIL. + // Thus, we protect the critical section modifying the `executables_` map + // and more generally the compilation with some `absl::Notification`. + // The first thread reaching such point will be responsible to create the + // notification for the executable and others will wait until notified. + // It's safe because the first thread will be holding the GIL while + // initializing the `Notification`. + // + // absl::optional is not supported + bool first_compilation_started_ = false; + absl::Notification first_compilation_complete_; + absl::optional first_compilation_error_ = absl::nullopt; +}; + +CompiledFunction::CompiledFunction(py::function fun, + py::function cache_miss_fun, + py::function python_f_jitted, + bool jax_enable_x64, bool jax_disable_jit, + std::vector static_argnums) + : fun_(std::move(fun)), + cache_miss_fun_(std::move(cache_miss_fun)), + python_f_jitted_(std::move(python_f_jitted)), + jax_enable_x64_(jax_enable_x64), + jax_disable_jit_(jax_disable_jit), + static_argnums_(std::move(static_argnums)) { + std::sort(static_argnums_.begin(), static_argnums_.end()); +} + +CompiledFunction::~CompiledFunction() { + for (const auto& entry : executables_) { + entry.first.DecRef(); + } +} + +namespace { + +// The resulting information of the parsing and conversion of the arguments. +struct ParsedArgumentsAsBuffers { + // The call signature will be filled during 2 steps: + // - `FlattenArguments` will fill the static arguments and the pytree + // structures + // - the shapes and dtypes are filled later, by `ParseAndTransferArguments`. + CallSignature signature; + // The concatenation of the dynamic positional arguments and the sorted + // keyword arguments. 
We do not need ownership, thus the py::handle. + // TODO(jblespiau): We do not need py::object here and py::handle suffice and + // will prevent any counter increment. + std::vector flat_dynamic_args; + std::vector keep_alive_objects; + + // The following is only valid if the parsing succeeds. + std::vector arg_buffers; + // We may need to keep some objects around, because: + // (a) we need to extend the lifetime of objects created within + // `ConvertArgsToBuffers` + // (b) `arg_buffers` do not maintain ownership + std::vector, + std::unique_ptr>> + keep_alive; +}; + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +void FlattenArguments(const py::args& args, const py::kwargs& py_kwargs, + absl::Span static_argnums, + ParsedArgumentsAsBuffers& arguments) { + arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() - + static_argnums.size()); + arguments.signature.dynamic_positional_args_treedef.reserve( + args.size() - static_argnums.size()); + + // Positional arguments. + for (size_t i = 0; i < args.size(); ++i) { + if (std::find(static_argnums.begin(), static_argnums.end(), i) == + static_argnums.end()) { + PyTreeDef pytree_def; + pytree_def.FlattenInto(args[i], arguments.flat_dynamic_args); + arguments.signature.dynamic_positional_args_treedef.push_back(pytree_def); + } else { + arguments.signature.static_args.emplace_back( + // borrow is mandatory here. + py::reinterpret_borrow(args[i])); + } + } + + // Keyword arguments. + std::vector> kwargs(py_kwargs.begin(), + py_kwargs.end()); + // We first intern the keys, then sort them (by pointer) and then create + // the signatures. + arguments.signature.keyword_args.resize(kwargs.size()); + for (size_t i = 0; i < kwargs.size(); ++i) { + // Intern the key if not already interned. + if (!PyUnicode_CHECK_INTERNED(kwargs[i].first.ptr())) { + PyObject* key = kwargs[i].first.ptr(); + kwargs[i].first.inc_ref(); + PyUnicode_InternInPlace(&key); + arguments.keep_alive_objects.push_back( + py::reinterpret_steal(key)); + kwargs[i].first = py::handle(key); + } + } + + std::sort(kwargs.begin(), kwargs.end(), + [](const std::pair& a, + const std::pair& b) { + return a.first.ptr() < b.first.ptr(); + }); + for (size_t i = 0; i < kwargs.size(); ++i) { + arguments.signature.keyword_args[i].key = kwargs[i].first; + arguments.signature.keyword_args[i].value_treedef.FlattenInto( + kwargs[i].second, arguments.flat_dynamic_args); + } +} + +template +std::unique_ptr ConvertToScalarBuffer( + const py::handle& scalar, xla::PjRtClient* client, + xla::PjRtDevice* device) { + CppType data = py::cast(scalar); + xla::Shape shape = xla::ShapeUtil::MakeShapeWithType({}); + return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer( + &data, shape, + xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr, + client, device)); +} + +// Convert a scalar to the associated PjRtBuffer or raises an error if it is +// not convertible (thus, this must be called after other checks). +StatusOr> ScalarToBuffer( + py::handle scalar, bool jax_enable_x64, xla::PjRtClient* client, + xla::PjRtDevice* device) { + // Important: In Python, isinstance(True, int) returns True. Thus, we have + // to check for bool before int. 
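  // (Illustrative aside, not part of the original change: CPython's bool is a
  // subclass of int, so an int check such as py::isinstance<py::int_> also
  // matches True/False; testing py::bool_ first is what lets booleans become
  // PRED-typed buffers instead of integer-typed ones.)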
+ if (py::isinstance(scalar)) { + return ConvertToScalarBuffer(scalar, client, device); + } else if (py::isinstance(scalar)) { + if (jax_enable_x64) { + return ConvertToScalarBuffer(scalar, client, device); + } else { + return ConvertToScalarBuffer(scalar, client, device); + } + } else if (py::isinstance(scalar)) { + if (jax_enable_x64) { + return ConvertToScalarBuffer(scalar, client, device); + + } else { + return ConvertToScalarBuffer(scalar, client, device); + } + } else if (PyComplex_Check(scalar.ptr())) { + Py_complex result = PyComplex_AsCComplex(scalar.ptr()); + if (result.real == -1.0 && PyErr_Occurred()) { + PyErr_Clear(); + throw std::runtime_error("Could not convert the complex number"); + } + if (jax_enable_x64) { + xla::complex128 data(result.real, result.imag); + xla::Shape shape = xla::ShapeUtil::MakeShapeWithType({}); + return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer( + &data, shape, + xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, + nullptr, client, device)); + } else { + xla::complex64 data(result.real, result.imag); + xla::Shape shape = xla::ShapeUtil::MakeShapeWithType({}); + return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer( + &data, shape, + xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, + nullptr, client, device)); + } + } + return InvalidArgument( + "%s", absl::StrCat( + "Not supported: The C++ jax jit execution path, only accepts " + "DeviceArray, Numpy arrays, or Python scalars. Got type ", + py::cast(scalar.get_type().str()))); +} + +const py::dtype* DtypeTo32BitDtype(const py::dtype& dtype) { + static const auto* int64_dt = new py::dtype("int64"); + static const auto* int32_dt = new py::dtype("int32"); + static const auto* uint64_dt = new py::dtype("uint64"); + static const auto* uint32_dt = new py::dtype("uint32"); + static const auto* float64_dt = new py::dtype("float64"); + static const auto* float32_dt = new py::dtype("float32"); + static const auto* complex64_dt = new py::dtype("complex64"); + static const auto* complex128_dt = new py::dtype("complex128"); + + if (dtype == *int64_dt) { + return int32_dt; + } + if (dtype == *float64_dt) { + return float32_dt; + } + if (dtype == *uint64_dt) { + return uint32_dt; + } + if (dtype == *complex128_dt) { + return complex64_dt; + } + + return nullptr; +} + +// Converts flattened arguments contained in ParsedArgumentsAsBuffers in +// place. If arguments are `DeviceArray`, they must all be on the same `Device`. +// +// Returns `OkStatus()` on success. +Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, + xla::PjRtDevice* default_device, + ParsedArgumentsAsBuffers& arguments) { + std::vector& arg_buffers = arguments.arg_buffers; + auto& keep_alive = arguments.keep_alive; + + int num_flat_dynamic_args = arguments.flat_dynamic_args.size(); + arg_buffers.reserve(num_flat_dynamic_args); + arguments.signature.dynamic_args_signatures.reserve(num_flat_dynamic_args); + + static const auto* xla_module = + new py::module(py::module::import("jax.interpreters.xla")); + const auto& device_array = xla_module->attr("DeviceArray"); + + static const auto* numpy_module = new py::module(py::module::import("numpy")); + const auto& array = numpy_module->attr("array"); + + // TODO(phawkins): consider device stickiness. + // We first check whether any `DeviceArray` is present and whether they are + // attached to any specific device. 
See also + // https://github.com/google/jax/pull/1884 + // https://github.com/google/jax/pull/1916 for the rationale why the + // computation follows the data locality. + // It's also similar to PyTorch's behavior. + xla::PjRtDevice* data_device = nullptr; + for (py::handle arg : arguments.flat_dynamic_args) { + if (py::isinstance(arg, device_array)) { + xla::PyBuffer* buffer; + try { + // This can fail, e.g. when device_buffer is a `DeviceConstant`. + buffer = py::cast(arg.attr("device_buffer")); + } catch (const py::cast_error& e) { + return InvalidArgument( + "%s", + absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: " + "`device_buffer` field is of type ", + py::cast( + arg.attr("device_buffer").get_type().str()), + " while a `PyBuffer` was expected." + + )); + } + xla::PjRtDevice* device = buffer->buffer()->device(); + if (data_device && (device != data_device)) { + return InvalidArgument( + "%s", + absl::StrCat( + "Arguments to a jit-compiled function must be colocated on the " + "same device. Arguments were found to be on the two following " + "different devices: ", + device->DebugString(), " and ", data_device->DebugString())); + } else { + data_device = device; + } + } + } + if (!data_device) { + // No `DeviceArray` was found; default to `default_device`. + data_device = default_device; + } + CHECK(data_device); + arguments.signature.device = data_device; + xla::PjRtClient* pjrt_client = data_device->client(); + + for (py::handle arg : arguments.flat_dynamic_args) { + // We do not support transparent d2d transfers here. + // We assume all the `DeviceArray`s are already on the correct and shared + // device. + if (py::isinstance(arg, device_array)) { + xla::PyBuffer* buffer = + py::cast(arg.attr("device_buffer")); + arg_buffers.push_back(buffer->buffer()); + ArgSignature sig; + sig.dtype = buffer->shape().element_type(); + sig.shape.assign(buffer->shape().dimensions().begin(), + buffer->shape().dimensions().end()); + sig.weak_type = py::cast(arg.attr("aval").attr("weak_type")); + arguments.signature.dynamic_args_signatures.push_back(std::move(sig)); + } else if (py::isinstance(arg)) { + // TODO(jblespiau): Can we improve this call? Do we need the underlying + // GlobalPyRefManager() and co? + py::array numpy_array = py::cast(arg); + // If jax_enable_x64 is not set, we need to coerce 64-bit types to 32 bits. + // Note that this is calling back to Python!
+ if (!jax_enable_x64) { + const py::dtype* to_dtype = DtypeTo32BitDtype(numpy_array.dtype()); + if (to_dtype) { + numpy_array = array(numpy_array, to_dtype); + } + } + std::unique_ptr buffer = + ValueOrThrow(pyclient.BufferFromPyval( + numpy_array, data_device, + /*force_copy=*/false, /*host_buffer_semantics=*/ + xla::PjRtBuffer::HostBufferSemantics::kZeroCopy)); + arg_buffers.push_back(buffer->buffer()); + + ArgSignature sig; + sig.dtype = buffer->shape().element_type(); + sig.weak_type = false; + sig.shape.assign(buffer->shape().dimensions().begin(), + buffer->shape().dimensions().end()); + arguments.signature.dynamic_args_signatures.push_back(sig); + + keep_alive.emplace_back(std::move(buffer)); + } else { + StatusOr> buffer = + ScalarToBuffer(arg, jax_enable_x64, pjrt_client, data_device); + if (!buffer.ok()) { + return buffer.status(); + } + arg_buffers.push_back(buffer.ValueOrDie().get()); + ArgSignature sig; + sig.dtype = buffer.ValueOrDie()->on_host_shape().element_type(); + sig.weak_type = true; + arguments.signature.dynamic_args_signatures.push_back(sig); + + keep_alive.emplace_back(std::move(buffer).ValueOrDie()); + } + } + return Status::OK(); +} + +} // namespace + +CacheEntry& CompiledFunction::GetCacheEntry( + const py::args& args, const py::kwargs& kwargs, + const CallSignature& signature, + absl::optional cache_miss_return) { + auto found_iterator = executables_.find(signature); + if (found_iterator != executables_.end()) { // Cache hit! + if (!found_iterator->second->compilation_complete.HasBeenNotified()) { + py::gil_scoped_release gil_release; + found_iterator->second->compilation_complete.WaitForNotification(); + if (found_iterator->second->compilation_error) { + throw found_iterator->second->compilation_error.value(); + } + } + return *(found_iterator->second); + } + return SetAndReturnCacheEntry(args, kwargs, signature, cache_miss_return); +} +CacheEntry& CompiledFunction::SetAndReturnCacheEntry( + const py::args& args, const py::kwargs& kwargs, + const CallSignature& signature, + absl::optional cache_miss_return) { + // We need to insert the element. + auto result = executables_.emplace(signature, std::make_unique()); + auto it = result.first; + CacheEntry& cache_entry = *(it->second.get()); + // CallSignatures in the cache own their keyword argument reference. + result.first->first.IncRef(); + + // Cache miss? Call the Python cache miss function. + py::tuple executable_and_pytree; + if (cache_miss_return) { + executable_and_pytree = cache_miss_return.value(); + } else { + try { + executable_and_pytree = cache_miss_fun_(*args, **kwargs); + } catch (const std::exception& e) { + cache_entry.compilation_error = e; + cache_entry.compilation_complete.Notify(); + throw; + } + } + if (executable_and_pytree.size() != 4) { + throw std::runtime_error( + "AssertionError: The cache miss function should return 4 " + "arguments."); + } + cache_entry.executable = py::cast>( + std::move(executable_and_pytree[0])); + int num_devices = + cache_entry.executable->pjrt_executable().local_devices().size(); + if (num_devices != 1) { + throw std::runtime_error(absl::StrCat( + "Running on more than a single device is not currently supported." 
+ "The underlying PjRtExecutable has ", + num_devices)); + } + cache_entry.device = + cache_entry.executable->pjrt_executable().local_devices()[0]; + cache_entry.out_pytree_def = py::cast(executable_and_pytree[1]); + + py::list shaped_arrays = + py::reinterpret_borrow(executable_and_pytree[2]); + py::list lazy_expressions = + py::reinterpret_borrow(executable_and_pytree[3]); + + cache_entry.out_avals.reserve(shaped_arrays.size()); + cache_entry.out_lazy_exprs.reserve(lazy_expressions.size()); + + int num_outputs = shaped_arrays.size(); + for (int i = 0; i < num_outputs; ++i) { + py::object shaped_array = + py::reinterpret_borrow(shaped_arrays[i]); + py::object lazy_expr = + py::reinterpret_borrow(lazy_expressions[i]); + + cache_entry.out_avals.push_back(shaped_array); + cache_entry.out_lazy_exprs.push_back(lazy_expr); + } + + cache_entry.compilation_complete.Notify(); + return cache_entry; +} + +py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { + if (JitIsDisabled()) { + return fun_(*args, **kwargs); + } + ParsedArgumentsAsBuffers arguments; + FlattenArguments(args, kwargs, static_argnums_, arguments); + + // TODO(jblespiau): It would be preferable to have a single location for + // locking code. + absl::optional cache_miss_result = absl::nullopt; + if (!default_device_) { + // TODO(jblespiau): This code will deadlock if a jitted function + // recursively calls itself. + if (first_compilation_started_) { + if (!first_compilation_complete_.HasBeenNotified()) { + py::gil_scoped_release gil_release; + first_compilation_complete_.WaitForNotification(); + if (first_compilation_error_) { + throw first_compilation_error_.value(); + } + } + } else { + first_compilation_started_ = true; + try { + cache_miss_result = cache_miss_fun_(*args, **kwargs); + } catch (const std::exception& e) { + first_compilation_error_ = e; + first_compilation_complete_.Notify(); + throw; + } + auto executable = py::cast>( + cache_miss_result.value()[0]); + + pyclient_ = executable->client(); + default_device_ = executable->LocalDevices()[0].contents; + first_compilation_complete_.Notify(); + } + } + + // The C++ jit do not support Tracers arguments yet. The Python-based jit + // function will be called if any of the dynamic arguments is unsupported. 
+ if (!ConvertArgsToBuffers(jax_enable_x64_, *pyclient_, default_device_, + arguments) + .ok()) { + return python_f_jitted_(*args, **kwargs); + } + + CacheEntry& cache_entry = + GetCacheEntry(args, kwargs, arguments.signature, cache_miss_result); + + std::vector> outputs = + ValueOrThrow(cache_entry.executable->PjRtExecute(arguments.arg_buffers)); + + static const auto* xla_module = + new py::module(py::module::import("jax.interpreters.xla")); + const auto& device_array = xla_module->attr("DeviceArray"); + + const std::vector& out_avals = cache_entry.out_avals; + const std::vector& out_lazy_exprs = cache_entry.out_lazy_exprs; + + py::list flat_device_arrays; + for (int i = 0; i < outputs.size(); ++i) { + flat_device_arrays.append(device_array( + /*aval=*/out_avals[i], /*device=*/outputs[i]->device(), + /*lazy_expr=*/out_lazy_exprs[i], + /*device_buffer=*/std::move(outputs[i]))); + } + return cache_entry.out_pytree_def.Unflatten(flat_device_arrays); +} + +} // namespace + +void BuildJaxjitSubmodule(pybind11::module& m) { + py::module jitlib = m.def_submodule("jax_jit", "Jax C++ jit library"); + + py::class_> cfun( + jitlib, "CompiledFunction"); + cfun.def("__call__", &CompiledFunction::Call); + cfun.def_property_readonly("__signature__", &CompiledFunction::__signature__); + + jitlib.def("set_disable_jit", &SetDisableJit); + jitlib.def("get_disable_jit", &GetDisableJit); + jitlib.def( + "jit", + [](py::function fun, py::function cache_miss_fun, + py::function fallback_on_unsupported_argument, bool jax_enable_x64, + bool jax_disable_jit, + std::vector static_argnums) -> std::unique_ptr { + return std::make_unique( + std::move(fun), std::move(cache_miss_fun), + std::move(fallback_on_unsupported_argument), jax_enable_x64, + jax_disable_jit, std::move(static_argnums)); + }); + + // Only for testing purposes + jitlib.def("_ScalarToBuffer", [](py::handle scalar, bool jax_enable_x64, + std::shared_ptr client) { + xla::PjRtClient* pjrt_client = client->pjrt_client(); + + return std::make_unique( + client, + ScalarToBuffer(scalar, jax_enable_x64, pjrt_client, + pjrt_client->local_devices()[0]) + .ValueOrDie(), + nullptr); + }); +} + +} // namespace xla diff --git a/tensorflow/python/util/tf32.cc b/tensorflow/compiler/xla/python/jax_jit.h similarity index 74% rename from tensorflow/python/util/tf32.cc rename to tensorflow/compiler/xla/python/jax_jit.h index 7dece6ccdae..2b1603aac27 100644 --- a/tensorflow/python/util/tf32.cc +++ b/tensorflow/compiler/xla/python/jax_jit.h @@ -13,10 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "pybind11/pybind11.h" -#include "tensorflow/core/platform/tf32_utils.h" +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ -PYBIND11_MODULE(_pywrap_tf32_execution, m) { - m.def("allow", &tensorflow::allow_tf32_execution); - m.def("is_allowed", &tensorflow::tf32_execution_allowed); -} +#include "pybind11/pybind11.h" + +namespace xla { + +void BuildJaxjitSubmodule(pybind11::module& m); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc index 3ac4709b160..f8099412c73 100644 --- a/tensorflow/compiler/xla/python/ops.cc +++ b/tensorflow/compiler/xla/python/ops.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "pybind11/attr.h" #include "pybind11/pybind11.h" #include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/lu_decomposition.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/qr.h" #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" @@ -186,6 +187,13 @@ void BuildOpsSubmodule(py::module* m) { return std::make_pair(qr.q, qr.r); }, py::arg("operand"), py::arg("full_matrices")); + ops.def( + "LU", + [](XlaOp a) -> StatusOr> { + LuDecompositionResult lu = LuDecomposition(a); + return std::make_tuple(lu.lu, lu.pivots, lu.permutation); + }, + py::arg("operand")); ops.def( "Eigh", [](XlaOp a, bool lower, int64 max_iter, diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc index 7c029ca7d19..f6067e650c0 100644 --- a/tensorflow/compiler/xla/python/outfeed_receiver.cc +++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc @@ -101,14 +101,14 @@ uint32_t constexpr kOutfeedCidShutdown = 0; // Encapsulates data received from a device outfeed. class OutfeedData { public: - OutfeedData(Device* device, uint32_t consumer_id, Shape shape) + OutfeedData(PjRtDevice* device, uint32_t consumer_id, Shape shape) : device_(device), consumer_id_(consumer_id), shape_(shape), literal_(nullptr), literal_size_bytes_(0) {} - Device* device() { return device_; } + PjRtDevice* device() { return device_; } uint32_t consumer_id() const { return consumer_id_; } Shape shape() const { return shape_; } std::unique_ptr literal() { @@ -123,7 +123,7 @@ class OutfeedData { std::string DebugString() const; private: - Device* device_; + PjRtDevice* device_; uint32_t consumer_id_; Shape shape_; std::unique_ptr literal_; @@ -187,8 +187,8 @@ class OutfeedReceiverImpl { Status SendShutdownOutfeedHeader(int device_idx); // Receives a raw Literal from a device outfeed. - StatusOr> ReceiveRawFromOutfeed(const Device* device, - const Shape& shape); + StatusOr> ReceiveRawFromOutfeed( + const PjRtDevice* device, const Shape& shape); // Enqueues received data in the callbaback queue. void EnqueueReceivedData(std::unique_ptr received) @@ -200,7 +200,7 @@ class OutfeedReceiverImpl { OutfeedReceiver::Callback callback_; // The devices on which we are listening. - std::vector devices_; + std::vector devices_; // Maximum bytes capacity of the callback queue. 
uint64_t max_callback_queue_size_bytes_; @@ -283,7 +283,7 @@ void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) { absl::MutexLock lock(&mu_); ++num_listening_threads_; } - Device* device = devices_[device_idx]; + PjRtDevice* device = devices_[device_idx]; while (true) { Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords}); std::unique_ptr header = @@ -339,7 +339,7 @@ void OutfeedReceiverImpl::EnqueueReceivedData( } StatusOr> OutfeedReceiverImpl::ReceiveRawFromOutfeed( - const Device* device, const Shape& shape) { + const PjRtDevice* device, const Shape& shape) { std::shared_ptr literal_shared; TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, @@ -390,7 +390,7 @@ void OutfeedReceiverImpl::CallbackThreadLoop() { } Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) { - const Device* device = devices_[device_idx]; + const PjRtDevice* device = devices_[device_idx]; constexpr int consumer_id = kOutfeedCidShutdown; VLOG(2) << "[" << device->DebugString() << "] SendSpecialHeader cons=" << consumer_id; diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.h b/tensorflow/compiler/xla/python/outfeed_receiver.h index a8dcc559810..46e2e5d9526 100644 --- a/tensorflow/compiler/xla/python/outfeed_receiver.h +++ b/tensorflow/compiler/xla/python/outfeed_receiver.h @@ -33,7 +33,7 @@ class OutfeedReceiver { public: // A callback takes: device, consumer id, received. using Callback = - std::function)>; + std::function)>; // Constructs the receiver for the given clients and callback function. // diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_py.cc b/tensorflow/compiler/xla/python/outfeed_receiver_py.cc index d297df332ff..a732ab8e21a 100644 --- a/tensorflow/compiler/xla/python/outfeed_receiver_py.cc +++ b/tensorflow/compiler/xla/python/outfeed_receiver_py.cc @@ -40,7 +40,7 @@ class OutfeedReceiverForPython { public: // A callback to Python takes: consumer id, received literal. using CallbackToPython = - std::function, uint32_t, pybind11::object)>; + std::function, uint32_t, pybind11::object)>; OutfeedReceiverForPython(CallbackToPython callback_python, std::vector> clients, @@ -48,7 +48,7 @@ class OutfeedReceiverForPython { : callback_python_(std::move(callback_python)), clients_(std::move(clients)) { OutfeedReceiver::Callback callback = - [this](Device* device, uint32_t consumer_id, + [this](PjRtDevice* device, uint32_t consumer_id, std::shared_ptr literal) { this->Callback(device, consumer_id, std::move(literal)); }; @@ -86,7 +86,7 @@ class OutfeedReceiverForPython { arrays); } - void Callback(Device* device, uint32_t consumer_id, + void Callback(PjRtDevice* device, uint32_t consumer_id, std::shared_ptr literal) { { absl::MutexLock lock(&mu_); @@ -106,7 +106,7 @@ class OutfeedReceiverForPython { LiteralToPython(std::move(literal)).ValueOrDie(); // The callback_ should handle all exceptions in user-code. If we get // an exception here, it is a bug in the callback and we should stop. 
- callback_python_(WrapWithClient(*it, device), consumer_id, + callback_python_(WrapWithClient(*it, device), consumer_id, std::move(literal_python)); } diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc index e8a5063b70b..919dafe2e0b 100644 --- a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc +++ b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc @@ -78,11 +78,11 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) { std::vector clients{cpu_client.get()}; auto receiver = absl::make_unique(); - OutfeedReceiver::Callback callback = [&receiver]( - Device* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; + OutfeedReceiver::Callback callback = + [&receiver](PjRtDevice* device, uint32_t consumer_id, + std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; auto outfeed_receiver = std::make_shared(callback, clients, 128); outfeed_receiver->Start(); @@ -111,11 +111,11 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) { std::vector clients{cpu_client.get()}; auto receiver = absl::make_unique(); - OutfeedReceiver::Callback callback = [&receiver]( - Device* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; + OutfeedReceiver::Callback callback = + [&receiver](PjRtDevice* device, uint32_t consumer_id, + std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; auto outfeed_receiver = std::make_shared(callback, clients, 128); outfeed_receiver->Start(); @@ -156,11 +156,11 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) { std::vector clients{cpu_client.get()}; auto receiver = absl::make_unique(); - OutfeedReceiver::Callback callback = [&receiver]( - Device* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; + OutfeedReceiver::Callback callback = + [&receiver](PjRtDevice* device, uint32_t consumer_id, + std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; auto outfeed_receiver = std::make_shared(callback, clients, 128); outfeed_receiver->Start(); @@ -199,11 +199,11 @@ TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) { std::vector clients{cpu_client.get()}; auto receiver = absl::make_unique(); - OutfeedReceiver::Callback callback = [&receiver]( - Device* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; + OutfeedReceiver::Callback callback = + [&receiver](PjRtDevice* device, uint32_t consumer_id, + std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; auto outfeed_receiver = std::make_shared(callback, clients, 128); outfeed_receiver->Start(); @@ -233,11 +233,11 @@ TEST(OutfeedReceiverTest, InvalidConsumerIdError) { std::vector clients{cpu_client.get()}; auto receiver = absl::make_unique(); - OutfeedReceiver::Callback callback = [&receiver]( - Device* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; + OutfeedReceiver::Callback callback = + [&receiver](PjRtDevice* device, uint32_t consumer_id, + std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; auto outfeed_receiver = std::make_shared(callback, clients, 128); outfeed_receiver->Start(); diff --git a/tensorflow/compiler/xla/python/py_buffer.cc b/tensorflow/compiler/xla/python/py_buffer.cc index ed4787310b4..b32fe047530 100644 --- a/tensorflow/compiler/xla/python/py_buffer.cc +++ 
b/tensorflow/compiler/xla/python/py_buffer.cc @@ -51,12 +51,12 @@ PyBuffer::~PyBuffer() { } } -ClientAndPtr PyBuffer::device() const { +ClientAndPtr PyBuffer::device() const { return WrapWithClient(client_, buffer_->device()); } StatusOr> PyBuffer::CopyToDevice( - const ClientAndPtr& dst_device) const { + const ClientAndPtr& dst_device) const { CHECK(dst_device.get() != nullptr); GlobalPyRefManager()->CollectGarbage(); std::unique_ptr out; diff --git a/tensorflow/compiler/xla/python/py_buffer.h b/tensorflow/compiler/xla/python/py_buffer.h index 76791e969cb..d7906574ec1 100644 --- a/tensorflow/compiler/xla/python/py_buffer.h +++ b/tensorflow/compiler/xla/python/py_buffer.h @@ -38,12 +38,12 @@ class PyBuffer { std::shared_ptr client() const { return client_; } PjRtBuffer* buffer() const { return buffer_.get(); } - ClientAndPtr device() const; + ClientAndPtr device() const; const std::string& platform_name() const { return buffer_->platform_name(); } bool is_deleted() const { return buffer_->IsDeleted(); } StatusOr> CopyToDevice( - const ClientAndPtr& dst_device) const; + const ClientAndPtr& dst_device) const; void Delete() { return buffer_->Delete(); } diff --git a/tensorflow/compiler/xla/python/py_client.cc b/tensorflow/compiler/xla/python/py_client.cc index 1f07c6e2042..6df11322564 100644 --- a/tensorflow/compiler/xla/python/py_client.cc +++ b/tensorflow/compiler/xla/python/py_client.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/python/py_client.h" +#include + #include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/python/py_buffer.h" #include "tensorflow/compiler/xla/python/py_executable.h" @@ -31,8 +33,8 @@ namespace pprof = tensorflow::tfprof::pprof; PyClient::PyClient(std::shared_ptr pjrt_client) : pjrt_client_(std::move(pjrt_client)) {} -std::vector> PyClient::Devices() { - std::vector> devices; +std::vector> PyClient::Devices() { + std::vector> devices; devices.reserve(pjrt_client_->devices().size()); for (const auto& device : pjrt_client_->devices()) { devices.push_back(WrapWithClient(shared_from_this(), device.get())); @@ -40,21 +42,21 @@ std::vector> PyClient::Devices() { return devices; } -std::vector> PyClient::LocalDevices() { - std::vector> devices; +std::vector> PyClient::LocalDevices() { + std::vector> devices; devices.reserve(pjrt_client_->local_devices().size()); - for (Device* device : pjrt_client_->local_devices()) { + for (PjRtDevice* device : pjrt_client_->local_devices()) { devices.push_back(WrapWithClient(shared_from_this(), device)); } return devices; } -StatusOr>>> +StatusOr>>> PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) { TF_ASSIGN_OR_RETURN( DeviceAssignment device_assignment, pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions)); - std::vector>> result; + std::vector>> result; result.resize(num_replicas); for (int r = 0; r < num_replicas; ++r) { result[r].resize(num_partitions); @@ -68,12 +70,12 @@ PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) { return result; } -StatusOr>> +StatusOr>> PyClient::GetDefaultDeviceAssignment1D(int num_replicas) { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, pjrt_client_->GetDefaultDeviceAssignment( num_replicas, /*num_partitions=*/1)); - std::vector> result; + std::vector> result; for (int i = 0; i < num_replicas; ++i) { int device_id = device_assignment(i, 0); auto iter = pjrt_client_->id_to_device().find(device_id); @@ -84,7 +86,7 @@ PyClient::GetDefaultDeviceAssignment1D(int 
num_replicas) { } StatusOr> PyClient::BufferFromPyval( - const pybind11::object& argument, Device* device, bool force_copy, + const pybind11::object& argument, PjRtDevice* device, bool force_copy, PjRtBuffer::HostBufferSemantics host_buffer_semantics) { if (device == nullptr) { TF_RET_CHECK(!pjrt_client_->local_devices().empty()); @@ -104,7 +106,6 @@ StatusOr> PyClient::BufferFromPyval( return InvalidArgument("from_python argument must be an array."); } - TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument)); std::shared_ptr py_buffer_ref = GlobalPyRefManager()->ManageReference(std::move(c->array)); @@ -121,7 +122,7 @@ StatusOr> PyClient::BufferFromPyval( std::move(traceback)); } -StatusOr> PyClient::Compile( +StatusOr> PyClient::Compile( const XlaComputation& computation, CompileOptions options) { std::unique_ptr executable; absl::optional fingerprint; @@ -134,7 +135,7 @@ StatusOr> PyClient::Compile( pjrt_client_->ExecutableFingerprint(*executable)); } auto traceback = Traceback::Get(); - return std::make_unique( + return std::make_shared( shared_from_this(), std::move(executable), std::move(traceback), std::move(fingerprint)); } @@ -205,7 +206,7 @@ namespace { struct HeapProfileKey { Traceback* traceback; int64 size; - Device* device; + PjRtDevice* device; bool operator==(const HeapProfileKey& other) const; }; diff --git a/tensorflow/compiler/xla/python/py_client.h b/tensorflow/compiler/xla/python/py_client.h index d33f3dadd7d..f12a4ae4f0a 100644 --- a/tensorflow/compiler/xla/python/py_client.h +++ b/tensorflow/compiler/xla/python/py_client.h @@ -100,14 +100,14 @@ class PyClient : public std::enable_shared_from_this { int device_count() const { return pjrt_client_->device_count(); } int host_id() const { return pjrt_client_->host_id(); } - std::vector> Devices(); - std::vector> LocalDevices(); + std::vector> Devices(); + std::vector> LocalDevices(); - StatusOr>>> + StatusOr>>> GetDefaultDeviceAssignment(int num_replicas, int num_partitions); // TODO(skye): delete after all callers can handle 2D output - StatusOr>> GetDefaultDeviceAssignment1D( + StatusOr>> GetDefaultDeviceAssignment1D( int num_replicas); StatusOr CreateChannelHandle() { @@ -121,10 +121,10 @@ class PyClient : public std::enable_shared_from_this { } StatusOr> BufferFromPyval( - const pybind11::object& argument, Device* device, bool force_copy, + const pybind11::object& argument, PjRtDevice* device, bool force_copy, PjRtBuffer::HostBufferSemantics host_buffer_semantics); - StatusOr> Compile( + StatusOr> Compile( const XlaComputation& computation, CompileOptions options); pybind11::bytes HeapProfile(); diff --git a/tensorflow/compiler/xla/python/py_executable.cc b/tensorflow/compiler/xla/python/py_executable.cc index b2cd2af56ea..53891b96846 100644 --- a/tensorflow/compiler/xla/python/py_executable.cc +++ b/tensorflow/compiler/xla/python/py_executable.cc @@ -37,7 +37,9 @@ PyExecutable::PyExecutable(std::shared_ptr client, if (next_) { next_->prev_ = this; } + options_.untuple_result = true; if (fingerprint_) { + options_.launch_id = tensorflow::Fingerprint32(*fingerprint_); VLOG(1) << "Fingerprint for executable " << executable_->name() << ": " << *fingerprint_; } @@ -56,30 +58,42 @@ PyExecutable::~PyExecutable() { } } -std::vector> PyExecutable::LocalDevices() const { - std::vector> devices; +std::vector> PyExecutable::LocalDevices() const { + std::vector> devices; devices.reserve(executable_->local_devices().size()); - for (Device* device : executable_->local_devices()) { + for (PjRtDevice* device : 
executable_->local_devices()) { devices.push_back(WrapWithClient(client_, device)); } return devices; } +StatusOr>> PyExecutable::PjRtExecute( + absl::Span args) { + std::vector> output_buffers; + { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(output_buffers, executable_->Execute(args, options_)); + } + auto traceback = Traceback::Get(); + std::vector> outputs; + outputs.reserve(output_buffers.size()); + for (auto& buffer : output_buffers) { + outputs.push_back( + std::make_unique(client_, std::move(buffer), traceback)); + } + return outputs; +} + StatusOr>> PyExecutable::Execute( absl::Span args) { std::vector> output_buffers; { py::gil_scoped_release gil_release; - ExecuteOptions options; - options.untuple_result = true; - if (fingerprint_) { - options.launch_id = tensorflow::Fingerprint32(*fingerprint_); - } std::vector arg_buffers(args.size()); absl::c_transform(args, arg_buffers.begin(), [](PyBuffer* buf) { return buf->buffer(); }); TF_ASSIGN_OR_RETURN(output_buffers, - executable_->Execute(arg_buffers, options)); + executable_->Execute(arg_buffers, options_)); } auto traceback = Traceback::Get(); std::vector> outputs; @@ -97,11 +111,6 @@ PyExecutable::ExecuteOnLocalDevices( std::vector>> output_buffers; { py::gil_scoped_release gil_release; - ExecuteOptions options; - options.untuple_result = true; - if (fingerprint_) { - options.launch_id = tensorflow::Fingerprint32(*fingerprint_); - } std::vector> arg_buffers(args.size()); for (int computation = 0; computation < args.size(); ++computation) { arg_buffers[computation].resize(args[computation].size()); @@ -109,7 +118,7 @@ PyExecutable::ExecuteOnLocalDevices( [](PyBuffer* buf) { return buf->buffer(); }); } TF_ASSIGN_OR_RETURN(output_buffers, executable_->ExecuteOnLocalDevices( - arg_buffers, options)); + arg_buffers, options_)); } auto traceback = Traceback::Get(); std::vector>> outputs; diff --git a/tensorflow/compiler/xla/python/py_executable.h b/tensorflow/compiler/xla/python/py_executable.h index 1051d065335..2e51548ae51 100644 --- a/tensorflow/compiler/xla/python/py_executable.h +++ b/tensorflow/compiler/xla/python/py_executable.h @@ -47,7 +47,7 @@ class PyExecutable { return executable_->local_logical_device_ids(); } - std::vector> LocalDevices() const; + std::vector> LocalDevices() const; int64 SizeOfGeneratedCodeInBytes() const { return executable_->SizeOfGeneratedCodeInBytes(); @@ -58,6 +58,10 @@ class PyExecutable { StatusOr>> Execute( absl::Span args); + // Same as above, but take as inputs `PjRtBuffer*`. Only targets C++ code. + StatusOr>> PjRtExecute( + absl::Span args); + StatusOr>>> ExecuteOnLocalDevices(absl::Span> args); @@ -65,6 +69,8 @@ class PyExecutable { Traceback* traceback() { return traceback_.get(); } + const PjRtExecutable& pjrt_executable() const { return *executable_; } + private: friend class PyClient; @@ -77,6 +83,9 @@ class PyExecutable { // aren't implemented. absl::optional fingerprint_; + // The options to pass to `executable_.Execute`. + ExecuteOptions options_; + // Doubly-linked list of all executables known to the client. Protected by the // GIL. PyExecutable* next_; diff --git a/tensorflow/compiler/xla/python/pytree.cc b/tensorflow/compiler/xla/python/pytree.cc new file mode 100644 index 00000000000..bf0bb1a8d93 --- /dev/null +++ b/tensorflow/compiler/xla/python/pytree.cc @@ -0,0 +1,648 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Caution: this code uses exceptions. The exception use is local to the +// binding code and the idiomatic way to emit Python exceptions. + +#include "tensorflow/compiler/xla/python/pytree.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/hash/hash.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" + +namespace xla { + +namespace py = pybind11; + +/*static*/ CustomNodeRegistry* CustomNodeRegistry::Singleton() { + static auto* registry = new CustomNodeRegistry; + return registry; +} + +/*static*/ void CustomNodeRegistry::Register(py::object type, + py::function to_iterable, + py::function from_iterable) { + CustomNodeRegistry* registry = Singleton(); + auto registration = absl::make_unique(); + registration->type = type; + registration->to_iterable = std::move(to_iterable); + registration->from_iterable = std::move(from_iterable); + auto it = registry->registrations_.emplace(type, std::move(registration)); + if (!it.second) { + throw std::invalid_argument( + absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.", + py::repr(type))); + } +} + +/*static*/ const CustomNodeRegistry::Registration* CustomNodeRegistry::Lookup( + py::handle type) { + CustomNodeRegistry* registry = Singleton(); + auto it = + registry->registrations_.find(py::reinterpret_borrow(type)); + return it == registry->registrations_.end() ? nullptr : it->second.get(); +} + +bool PyTreeDef::operator==(const PyTreeDef& other) const { + if (traversal_.size() != other.traversal_.size()) { + return false; + } + for (size_t i = 0; i < traversal_.size(); ++i) { + const Node& a = traversal_[i]; + const Node& b = other.traversal_[i]; + if (a.kind != b.kind || a.arity != b.arity || + (a.node_data.ptr() == nullptr) != (b.node_data.ptr() == nullptr) || + a.custom != b.custom) { + return false; + } + if (a.node_data && a.node_data.not_equal(b.node_data)) { + return false; + } + // We don't need to test equality of num_leaves and num_nodes since they + // are derivable from the other node data. + } + return true; +} + +/*static*/ PyTreeDef::Kind PyTreeDef::GetKind( + const py::handle& obj, CustomNodeRegistry::Registration const** custom) { + const PyObject* ptr = obj.ptr(); + if (PyTuple_CheckExact(ptr)) return Kind::kTuple; + if (PyList_CheckExact(ptr)) return Kind::kList; + if (PyDict_CheckExact(ptr)) return Kind::kDict; + if ((*custom = CustomNodeRegistry::Lookup(obj.get_type()))) { + return Kind::kCustom; + } else if (py::isinstance(obj)) { + return Kind::kNone; + } else if (py::isinstance(obj) && py::hasattr(obj, "_fields")) { + // We can only identify namedtuples heuristically, here by the presence of + // a _fields attribute. 
+ return Kind::kNamedTuple; + } else { + return Kind::kLeaf; + } +} + +void PyTreeDef::FlattenInto(py::handle handle, + std::vector& leaves) { + Node node; + int start_num_nodes = traversal_.size(); + int start_num_leaves = leaves.size(); + node.kind = GetKind(handle, &node.custom); + if (node.kind == Kind::kNone) { + // Nothing to do. + } else if (node.kind == Kind::kTuple) { + py::tuple tuple = py::reinterpret_borrow(handle); + node.arity = tuple.size(); + for (py::handle entry : tuple) { + FlattenInto(entry, leaves); + } + } else if (node.kind == Kind::kList) { + py::list list = py::reinterpret_borrow(handle); + node.arity = list.size(); + for (py::handle entry : list) { + FlattenInto(entry, leaves); + } + } else if (node.kind == Kind::kDict) { + py::dict dict = py::reinterpret_borrow(handle); + py::list keys = py::reinterpret_steal(PyDict_Keys(dict.ptr())); + if (PyList_Sort(keys.ptr())) { + throw std::runtime_error("Dictionary key sort failed."); + } + for (py::handle key : keys) { + FlattenInto(dict[key], leaves); + } + node.arity = dict.size(); + node.node_data = std::move(keys); + } else if (node.kind == Kind::kCustom) { + py::tuple out = py::cast(node.custom->to_iterable(handle)); + if (out.size() != 2) { + throw std::runtime_error( + "PyTree custom to_iterable function should return a pair"); + } + node.node_data = out[1]; + node.arity = 0; + for (py::handle entry : py::cast(out[0])) { + ++node.arity; + FlattenInto(entry, leaves); + } + } else if (node.kind == Kind::kNamedTuple) { + py::tuple tuple = py::reinterpret_borrow(handle); + node.arity = tuple.size(); + node.node_data = py::reinterpret_borrow(tuple.get_type()); + for (py::handle entry : tuple) { + FlattenInto(entry, leaves); + } + } else { + assert(node.kind == Kind::kLeaf); + leaves.push_back(pybind11::reinterpret_borrow(handle)); + } + node.num_nodes = traversal_.size() - start_num_nodes + 1; + node.num_leaves = leaves.size() - start_num_leaves; + traversal_.push_back(std::move(node)); +} + +/*static*/ std::pair, std::unique_ptr> +PyTreeDef::Flatten(py::handle x) { + std::vector leaves; + auto tree = absl::make_unique(); + tree->FlattenInto(x, leaves); + return std::make_pair(std::move(leaves), std::move(tree)); +} + +/*static*/ bool PyTreeDef::AllLeaves(const py::iterable& x) { + const CustomNodeRegistry::Registration* custom; + for (const py::handle& h : x) { + if (GetKind(h, &custom) != Kind::kLeaf) return false; + } + return true; +} + +py::object PyTreeDef::Unflatten(py::iterable leaves) const { + std::vector agenda; + auto it = leaves.begin(); + int leaf_count = 0; + for (const Node& node : traversal_) { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for TreeDef node."); + } + switch (node.kind) { + case Kind::kLeaf: + if (it == leaves.end()) { + throw std::invalid_argument(absl::StrFormat( + "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(), + leaf_count)); + } + agenda.push_back(py::reinterpret_borrow(*it)); + ++it; + ++leaf_count; + break; + + case Kind::kNone: + case Kind::kTuple: + case Kind::kNamedTuple: + case Kind::kList: + case Kind::kDict: + case Kind::kCustom: { + const int size = agenda.size(); + absl::Span span; + if (node.arity > 0) { + span = absl::Span(&agenda[size - node.arity], node.arity); + } + py::object o = MakeNode(node, span); + agenda.resize(size - node.arity); + agenda.push_back(o); + break; + } + } + } + if (it != leaves.end()) { + throw std::invalid_argument(absl::StrFormat( + "Too many leaves for PyTreeDef; expected %d.", num_leaves())); 
+ } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +/*static*/ py::object PyTreeDef::MakeNode(const PyTreeDef::Node& node, + absl::Span children) { + if (children.size() != node.arity) { + throw std::logic_error("Node arity mismatch."); + } + switch (node.kind) { + case Kind::kLeaf: + throw std::logic_error("MakeNode not implemented for leaves."); + + case Kind::kNone: + return py::none(); + + case Kind::kTuple: + case Kind::kNamedTuple: { + py::tuple tuple(node.arity); + for (int i = 0; i < node.arity; ++i) { + tuple[i] = std::move(children[i]); + } + if (node.kind == Kind::kNamedTuple) { + return node.node_data(*tuple); + } else { + return std::move(tuple); + } + } + + case Kind::kList: { + py::list list(node.arity); + for (int i = 0; i < node.arity; ++i) { + list[i] = std::move(children[i]); + } + return std::move(list); + } + + case Kind::kDict: { + py::dict dict; + py::list keys = py::reinterpret_borrow(node.node_data); + for (int i = 0; i < node.arity; ++i) { + dict[keys[i]] = std::move(children[i]); + } + return std::move(dict); + break; + } + case Kind::kCustom: { + py::tuple tuple(node.arity); + for (int i = 0; i < node.arity; ++i) { + tuple[i] = std::move(children[i]); + } + return node.custom->from_iterable(node.node_data, tuple); + } + } + throw std::logic_error("Unreachable code."); +} + +py::list PyTreeDef::FlattenUpTo(py::handle xs) const { + py::list leaves(num_leaves()); + std::vector agenda; + agenda.push_back(py::reinterpret_borrow(xs)); + auto it = traversal_.rbegin(); + int leaf = num_leaves() - 1; + while (!agenda.empty()) { + if (it == traversal_.rend()) { + throw std::invalid_argument(absl::StrFormat( + "Tree structures did not match: %s vs %s", py::repr(xs), ToString())); + } + const Node& node = *it; + py::object object = agenda.back(); + agenda.pop_back(); + ++it; + + switch (node.kind) { + case Kind::kLeaf: + if (leaf < 0) { + throw std::logic_error("Leaf count mismatch."); + } + leaves[leaf] = py::reinterpret_borrow(object); + --leaf; + break; + + case Kind::kNone: + break; + + case Kind::kTuple: { + if (!PyTuple_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected tuple, got %s.", py::repr(object))); + } + py::tuple tuple = py::reinterpret_borrow(object); + if (tuple.size() != node.arity) { + throw std::invalid_argument( + absl::StrFormat("Tuple arity mismatch: %d != %d; tuple: %s.", + tuple.size(), node.arity, py::repr(object))); + } + for (py::handle entry : tuple) { + agenda.push_back(py::reinterpret_borrow(entry)); + } + break; + } + + case Kind::kList: { + if (!PyList_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected list, got %s.", py::repr(object))); + } + py::list list = py::reinterpret_borrow(object); + if (list.size() != node.arity) { + throw std::invalid_argument( + absl::StrFormat("List arity mismatch: %d != %d; list: %s.", + list.size(), node.arity, py::repr(object))); + } + for (py::handle entry : list) { + agenda.push_back(py::reinterpret_borrow(entry)); + } + break; + } + + case Kind::kDict: { + if (!PyDict_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected dict, got %s.", py::repr(object))); + } + py::dict dict = py::reinterpret_borrow(object); + py::list keys = + py::reinterpret_steal(PyDict_Keys(dict.ptr())); + if (PyList_Sort(keys.ptr())) { + throw std::runtime_error("Dictionary key sort failed."); + } + if 
(keys.not_equal(node.node_data)) { + throw std::invalid_argument( + absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.", + py::repr(node.node_data), py::repr(object))); + } + for (py::handle key : keys) { + agenda.push_back(dict[key]); + } + break; + } + + case Kind::kNamedTuple: { + if (!py::isinstance(object) || + !py::hasattr(object, "_fields")) { + throw std::invalid_argument(absl::StrFormat( + "Expected named tuple, got %s.", py::repr(object))); + } + py::tuple tuple = py::reinterpret_borrow(object); + if (tuple.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), + node.arity, py::repr(object))); + } + if (tuple.get_type().not_equal(node.node_data)) { + throw std::invalid_argument(absl::StrFormat( + "Named tuple type mismatch: expected type: %s, tuple: %s.", + py::repr(node.node_data), py::repr(object))); + } + for (py::handle entry : tuple) { + agenda.push_back(py::reinterpret_borrow(entry)); + } + break; + } + + case Kind::kCustom: { + auto* registration = CustomNodeRegistry::Lookup(object.get_type()); + if (registration != node.custom) { + throw std::invalid_argument(absl::StrFormat( + "Custom node type mismatch: expected type: %s, value: %s.", + py::repr(node.custom->type), py::repr(object))); + } + py::tuple out = py::cast(node.custom->to_iterable(object)); + if (out.size() != 2) { + throw std::runtime_error( + "PyTree custom to_iterable function should return a pair"); + } + if (node.node_data.not_equal(out[1])) { + throw std::invalid_argument(absl::StrFormat( + "Mismatch custom node data: %s != %s; value: %s.", + py::repr(node.node_data), py::repr(out[1]), py::repr(object))); + } + int arity = 0; + for (py::handle entry : py::cast(out[0])) { + ++arity; + agenda.push_back(py::reinterpret_borrow(entry)); + } + if (arity != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Custom type arity mismatch: %d != %d; value: %s.", arity, + node.arity, py::repr(object))); + } + break; + } + } + } + if (it != traversal_.rend() || leaf != -1) { + throw std::invalid_argument(absl::StrFormat( + "Tree structures did not match: %s vs %s", py::repr(xs), ToString())); + } + return leaves; +} + +py::object PyTreeDef::Walk(const py::function& f_node, py::handle f_leaf, + py::iterable leaves) const { + std::vector agenda; + auto it = leaves.begin(); + for (const Node& node : traversal_) { + switch (node.kind) { + case Kind::kLeaf: { + if (it == leaves.end()) { + throw std::invalid_argument("Too few leaves for PyTreeDef"); + } + + py::object leaf = py::reinterpret_borrow(*it); + agenda.push_back(f_leaf.is_none() ? 
std::move(leaf) + : f_leaf(std::move(leaf))); + ++it; + break; + } + + case Kind::kNone: + case Kind::kTuple: + case Kind::kNamedTuple: + case Kind::kList: + case Kind::kDict: + case Kind::kCustom: { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for custom type."); + } + py::tuple tuple(node.arity); + for (int i = node.arity - 1; i >= 0; --i) { + tuple[i] = agenda.back(); + agenda.pop_back(); + } + agenda.push_back(f_node(tuple)); + } + } + } + if (it != leaves.end()) { + throw std::invalid_argument("Too many leaves for PyTreeDef"); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +py::object PyTreeDef::FromIterableTreeHelper( + py::handle xs, + std::vector::const_reverse_iterator* it) const { + if (*it == traversal_.rend()) { + throw std::invalid_argument("Tree structures did not match."); + } + const Node& node = **it; + ++*it; + if (node.kind == Kind::kLeaf) { + return py::reinterpret_borrow(xs); + } + py::iterable iterable = py::reinterpret_borrow(xs); + std::vector ys; + ys.reserve(node.arity); + for (py::handle x : iterable) { + ys.push_back(py::reinterpret_borrow(x)); + } + if (ys.size() != node.arity) { + throw std::invalid_argument("Arity mismatch between trees"); + } + for (int j = node.arity - 1; j >= 0; --j) { + ys[j] = FromIterableTreeHelper(ys[j], it); + } + + return MakeNode(node, absl::MakeSpan(ys)); +} + +py::object PyTreeDef::FromIterableTree(py::handle xs) const { + auto it = traversal_.rbegin(); + py::object out = FromIterableTreeHelper(xs, &it); + if (it != traversal_.rend()) { + throw std::invalid_argument("Tree structures did not match."); + } + return out; +} + +std::unique_ptr PyTreeDef::Compose(const PyTreeDef& inner) const { + auto out = absl::make_unique(); + for (const Node& n : traversal_) { + if (n.kind == Kind::kLeaf) { + absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_)); + } else { + out->traversal_.push_back(n); + } + } + const auto& root = traversal_.back(); + const auto& inner_root = inner.traversal_.back(); + // TODO(tomhennigan): This should update all nodes in the traversal. 
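A quick worked check of the size bookkeeping applied just below, with invented numbers: composing an outer treedef of 5 nodes / 3 leaves (for example ((*,*),*)) with an inner treedef of 3 nodes / 2 leaves (for example (*,*)) replaces every outer leaf with a copy of the inner tree, so the composed root should report (5 - 3) + 3 * 3 = 11 nodes and 3 * 2 = 6 leaves:

```cpp
#include <iostream>

int main() {
  // Hypothetical sizes: outer treedef ((*,*),*) and inner treedef (*,*).
  const int outer_nodes = 5, outer_leaves = 3;
  const int inner_nodes = 3, inner_leaves = 2;

  // Interior nodes of the outer tree are kept; each outer leaf becomes a full
  // copy of the inner tree.
  const int composed_nodes =
      (outer_nodes - outer_leaves) + inner_nodes * outer_leaves;
  const int composed_leaves = outer_leaves * inner_leaves;

  std::cout << composed_nodes << " nodes, " << composed_leaves << " leaves\n";
  // Prints: 11 nodes, 6 leaves
}
```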
+ auto& out_root = out->traversal_.back(); + out_root.num_nodes = (root.num_nodes - root.num_leaves) + + (inner_root.num_nodes * root.num_leaves); + out_root.num_leaves *= inner_root.num_leaves; + return out; +} + +/*static*/ std::unique_ptr PyTreeDef::Tuple( + const std::vector& defs) { + auto out = absl::make_unique(); + for (const PyTreeDef& def : defs) { + absl::c_copy(def.traversal_, std::back_inserter(out->traversal_)); + } + Node node; + node.kind = Kind::kTuple; + node.arity = defs.size(); + out->traversal_.push_back(node); + return out; +} + +std::vector> PyTreeDef::Children() const { + std::vector> children; + if (traversal_.empty()) { + return children; + } + Node const& root = traversal_.back(); + children.resize(root.arity); + int pos = traversal_.size() - 1; + for (int i = root.arity - 1; i >= 0; --i) { + children[i] = absl::make_unique(); + const Node& node = traversal_.at(pos - 1); + if (pos < node.num_nodes) { + throw std::logic_error("children() walked off start of array"); + } + std::copy(traversal_.begin() + pos - node.num_nodes, + traversal_.begin() + pos, + std::back_inserter(children[i]->traversal_)); + pos -= node.num_nodes; + } + if (pos != 0) { + throw std::logic_error("pos != 0 at end of PyTreeDef::Children"); + } + return children; +} + +std::string PyTreeDef::ToString() const { + std::vector agenda; + for (const Node& node : traversal_) { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for container."); + } + + std::string kind; + switch (node.kind) { + case Kind::kLeaf: + agenda.push_back("*"); + continue; + case Kind::kNone: + kind = "None"; + break; + case Kind::kNamedTuple: + kind = "namedtuple"; + break; + case Kind::kTuple: + kind = "tuple"; + break; + case Kind::kList: + kind = "list"; + break; + case Kind::kDict: + kind = "dict"; + break; + case Kind::kCustom: + kind = static_cast(py::str(node.custom->type)); + break; + } + + std::string children = + absl::StrJoin(agenda.end() - node.arity, agenda.end(), ","); + agenda.erase(agenda.end() - node.arity, agenda.end()); + + std::string data; + if (node.node_data) { + data = absl::StrFormat("[%s]", py::str(node.node_data)); + } + + agenda.push_back( + absl::StrFormat("PyTreeDef(%s%s, [%s])", kind, data, children)); + } + + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +void BuildPytreeSubmodule(py::module& m) { + py::module pytree = m.def_submodule("pytree", "Python tree library"); + pytree.def("flatten", &PyTreeDef::Flatten); + pytree.def("tuple", &PyTreeDef::Tuple); + pytree.def("all_leaves", &PyTreeDef::AllLeaves); + + py::class_(m, "PyTreeDef") + .def("unflatten", &PyTreeDef::Unflatten) + .def("flatten_up_to", &PyTreeDef::FlattenUpTo) + .def("compose", &PyTreeDef::Compose) + .def("walk", &PyTreeDef::Walk) + .def("from_iterable_tree", &PyTreeDef::FromIterableTree) + .def("children", &PyTreeDef::Children) + .def_property_readonly("num_leaves", &PyTreeDef::num_leaves) + .def_property_readonly("num_nodes", &PyTreeDef::num_nodes) + .def("__repr__", &PyTreeDef::ToString) + .def("__eq__", + [](const PyTreeDef& a, const PyTreeDef& b) { return a == b; }) + .def("__ne__", + [](const PyTreeDef& a, const PyTreeDef& b) { return a != b; }) + .def("__hash__", + [](const PyTreeDef& t) { return absl::Hash()(t); }); + + pytree.def("register_node", [](py::object type, py::function to_iterable, + py::function from_iterable) { + return CustomNodeRegistry::Register(type, to_iterable, 
from_iterable); + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/pytree.h b/tensorflow/compiler/xla/python/pytree.h new file mode 100644 index 00000000000..69cd93a7d08 --- /dev/null +++ b/tensorflow/compiler/xla/python/pytree.h @@ -0,0 +1,214 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_ + +// See https://jax.readthedocs.io/en/latest/pytrees.html for the documentation +// about pytree. + +// Caution: this code uses exceptions. The exception use is local to the +// binding code and the idiomatic way to emit Python exceptions. + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/hash/hash.h" +#include "absl/memory/memory.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" + +namespace xla { + +// Registry of custom node types. +class CustomNodeRegistry { + public: + struct Registration { + // The Python type object, used to identify the type. + pybind11::object type; + // A function with signature: object -> (iterable, aux_data) + pybind11::function to_iterable; + // A function with signature: (aux_data, iterable) -> object + pybind11::function from_iterable; + }; + + // Registers a new custom type. Objects of `type` will be treated as container + // node types in PyTrees. + static void Register(pybind11::object type, pybind11::function to_iterable, + pybind11::function from_iterable); + + // Finds the custom type registration for `type`. Returns nullptr if none + // exists. + static const Registration* Lookup(pybind11::handle type); + + private: + static CustomNodeRegistry* Singleton(); + + struct TypeHash { + size_t operator()(const pybind11::object& t) const { + return pybind11::hash(t); + } + }; + struct TypeEq { + bool operator()(const pybind11::object& a, + const pybind11::object& b) const { + return a.equal(b); + } + }; + absl::flat_hash_map, TypeHash, + TypeEq> + registrations_; +}; + +// A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of +// Python values, where the interior nodes are tuples, lists, dictionaries, or +// user-defined containers, and the leaves are other objects. +class PyTreeDef { + public: + PyTreeDef() = default; + + // Flattens a Pytree into a list of leaves and a PyTreeDef. + static std::pair, std::unique_ptr> + Flatten(pybind11::handle x); + + // Recursive helper used to implement Flatten(). + void FlattenInto(pybind11::handle handle, + std::vector& leaves); + + // Tests whether the given list is a flat list of leaves. + static bool AllLeaves(const pybind11::iterable& x); + + // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of + // the tree-structure of 'x'. 
For example, if we flatten a value + // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the + // list of leaves [1, (2, 3), {"foo": 4}]. + pybind11::list FlattenUpTo(pybind11::handle x) const; + + // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef. + pybind11::object Unflatten(pybind11::iterable leaves) const; + + // Composes two PyTreeDefs, replacing the leaves of this tree with copies of + // `inner`. + std::unique_ptr Compose(const PyTreeDef& inner) const; + + // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs. + static std::unique_ptr Tuple(const std::vector& defs); + + std::vector> Children() const; + + // Maps a function over a PyTree structure, applying f_leaf to each leaf, and + // f_node to each container node. + // TODO(phawkins): use flattening everywhere instead and delete this method. + pybind11::object Walk(const pybind11::function& f_node, + pybind11::handle f_leaf, + pybind11::iterable leaves) const; + + // Given a tree of iterables with the same node/leaf structure as this PyTree, + // build the corresponding PyTree. + // TODO(phawkins): use flattening everywhere instead and delete this method. + pybind11::object FromIterableTree(pybind11::handle xs) const; + + int num_leaves() const { + if (traversal_.empty()) { + return 0; + } + return traversal_.back().num_leaves; + } + + int num_nodes() const { return traversal_.size(); } + + size_t Hash() const; + + bool operator==(const PyTreeDef& other) const; + bool operator!=(const PyTreeDef& other) const { return !(*this == other); } + + std::string ToString() const; + + private: + enum class Kind { + kLeaf, // An opaque leaf node + kNone, // None. + kTuple, // A tuple + kNamedTuple, // A collections.namedtuple + kList, // A list + kDict, // A dict + kCustom, // A custom type. + }; + + struct Node { + Kind kind = Kind::kLeaf; + + // Arity for non-kLeaf types. + int arity = 0; + + // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type + // object. For a kDict, contains a sorted list of keys. For a kCustom type, + // contains the auxiliary data returned by the `to_iterable` function. + pybind11::object node_data; + + const CustomNodeRegistry::Registration* custom = nullptr; + + // Number of leaf nodes in the subtree rooted at this node. + int num_leaves = 0; + + // Number of leaf and interior nodes in the subtree rooted at this node. + int num_nodes = 0; + }; + template + friend H AbslHashValue(H h, const Node& n); + + template + friend H AbslHashValue(H h, const PyTreeDef& t); + + // Helper that manufactures an instance of a node given its children. + static pybind11::object MakeNode(const Node& node, + absl::Span children); + + // Recursive helper used to implement FromIterableTree() + pybind11::object FromIterableTreeHelper( + pybind11::handle xs, + std::vector::const_reverse_iterator* it) const; + + // Computes the node kind of a given Python object. + static Kind GetKind(const pybind11::handle& obj, + CustomNodeRegistry::Registration const** custom); + + // Nodes, in a post-order traversal. We use an ordered traversal to minimize + // allocations, and post-order corresponds to the order we need to rebuild the + // tree structure. 
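The comment above states the central invariant of this file: nodes are stored children-first (post-order), so rebuilding a tree is a single left-to-right pass over traversal_ with a stack of partially built values (the agenda in Unflatten). A toy standalone version of that scheme, with strings standing in for Python objects; ToyNode and Rebuild are illustrative names written for this note, not part of the patch:

```cpp
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

// A toy version of the post-order `traversal_` encoding: each node records
// only its kind and arity, and children precede their parent.
struct ToyNode {
  std::string kind;  // "leaf" or "tuple"
  int arity = 0;
};

// Rebuilds a bracketed string from leaves plus a post-order traversal, using
// the same stack ("agenda") scheme as PyTreeDef::Unflatten.
std::string Rebuild(const std::vector<ToyNode>& traversal,
                    const std::vector<std::string>& leaves) {
  std::vector<std::string> agenda;
  size_t next_leaf = 0;
  for (const ToyNode& node : traversal) {
    if (node.kind == "leaf") {
      agenda.push_back(leaves[next_leaf++]);
      continue;
    }
    assert(agenda.size() >= static_cast<size_t>(node.arity));
    std::string joined;
    for (size_t i = agenda.size() - node.arity; i < agenda.size(); ++i) {
      if (!joined.empty()) joined += ", ";
      joined += agenda[i];
    }
    agenda.resize(agenda.size() - node.arity);
    agenda.push_back("(" + joined + ")");
  }
  assert(agenda.size() == 1);
  return agenda.back();
}

int main() {
  // Encodes the structure ((a, b), c): two leaves, the inner pair, another
  // leaf, then the root pair; children always appear before their parent.
  std::vector<ToyNode> traversal = {
      {"leaf"}, {"leaf"}, {"tuple", 2}, {"leaf"}, {"tuple", 2}};
  std::cout << Rebuild(traversal, {"a", "b", "c"}) << "\n";  // ((a, b), c)
}
```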
+ std::vector traversal_; +}; + +template +H AbslHashValue(H h, const PyTreeDef::Node& n) { + h = H::combine(std::move(h), n.kind, n.arity, n.custom); + return h; +} + +template +H AbslHashValue(H h, const PyTreeDef& t) { + return H::combine_contiguous(std::move(h), t.traversal_.data(), + t.traversal_.size()); +} + +void BuildPytreeSubmodule(pybind11::module& m); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_ diff --git a/tensorflow/compiler/xla/python/tpu_driver/BUILD b/tensorflow/compiler/xla/python/tpu_driver/BUILD index 4725becdedf..70aeb3f2a86 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/BUILD @@ -115,6 +115,24 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "pod_tpu_driver", + srcs = ["pod_tpu_driver.cc"], + deps = [ + ":grpc_tpu_driver", + ":tpu_driver", + ":tpu_driver_proto_cc", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_set", + "//tensorflow/compiler/xla/pjrt:semaphore", + "//tensorflow/compiler/xla/pjrt:worker_thread", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + tf_grpc_cc_dependency(), + ] + external_deps(), + alwayslink = 1, +) + go_proto_library( name = "tpu_service_go_proto", compatible_with = ["//buildenv/target:gce"], diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index c460cc36f08..30a220ece45 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -24,6 +24,7 @@ cc_library( "//tensorflow/compiler/xla/python/tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:grpc_tpu_driver", + "//tensorflow/compiler/xla/python/tpu_driver:pod_tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:recording_tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:tpu_driver_proto_cc", "//tensorflow/compiler/xla/service:computation_placer", diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index e78f04ff980..0602d096aaa 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -37,8 +37,8 @@ namespace xla { TpuDevice::TpuDevice(int id, int host_id, const std::array& coords, int core_on_chip) - : xla::Device(id, /*local_device_state=*/nullptr, kTpuPlatform, - /*device_kind=*/"Cloud TPU", host_id), + : xla::PjRtDevice(id, /*local_device_state=*/nullptr, kTpuPlatform, + /*device_kind=*/"Cloud TPU", host_id), coords_(coords), core_on_chip_(core_on_chip) {} @@ -47,9 +47,9 @@ std::string TpuDevice::DebugString() const { coords_[0], coords_[1], coords_[2], core_on_chip_); } -xla::StatusOr>> +xla::StatusOr>> TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) { - std::vector> devices; + std::vector> devices; for (const auto& chip : system_info.tpu_chip()) { auto& coord = chip.chip_coord(); std::array coords_array = {coord.x(), coord.y(), coord.z()}; @@ -78,7 +78,7 @@ StatusOr> PyTpuClient::Get( tpu_driver::SystemInfo system_info; client->QuerySystemInfo(&system_info); - TF_ASSIGN_OR_RETURN(std::vector> devices, + TF_ASSIGN_OR_RETURN(std::vector> devices, TpuDevice::GetTpuDevices(system_info)); return std::make_shared(kTpuPlatform, std::move(client), @@ -88,13 +88,13 @@ StatusOr> PyTpuClient::Get( 
PyTpuClient::PyTpuClient(std::string platform_name, std::unique_ptr driver, - std::vector> devices, + std::vector> devices, int host_id) : platform_name_(std::move(platform_name)), driver_(std::move(driver)), devices_(std::move(devices)), host_id_(host_id) { - for (const std::shared_ptr& device : devices_) { + for (const std::shared_ptr& device : devices_) { CHECK(id_to_device_.insert({device->id(), device}).second) << "Duplicate device id: " << device->id(); @@ -173,7 +173,7 @@ static Status CheckDataType(xla::PrimitiveType dtype) { StatusOr> PyTpuBuffer::FromLiterals( std::vector leaves, const Shape& tuple_shape, std::shared_ptr leaves_references, - std::shared_ptr client, std::shared_ptr device) { + std::shared_ptr client, std::shared_ptr device) { tensorflow::profiler::TraceMe traceme("PyTpuBuffer::FromLiterals"); VLOG(1) << "PyTpuBuffer::FromLiterals: shape: " << tuple_shape.DebugString() << " device: " << device->DebugString(); @@ -229,7 +229,7 @@ StatusOr> PyTpuBuffer::FromLiterals( /* static */ StatusOr> PyTpuBuffer::MakeTuple( absl::Span buffers, std::shared_ptr client, - std::shared_ptr device) { + std::shared_ptr device) { std::vector child_shapes; std::vector> child_device_buffers; std::vector child_handle_ptrs; @@ -388,7 +388,7 @@ PyTpuBuffer::DestructureTuple() { } StatusOr> PyTpuBuffer::CopyToDevice( - std::shared_ptr dst_device) { + std::shared_ptr dst_device) { tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CopyToDevice"); if (on_host_shape_.IsTuple()) { return Unimplemented("CopyToDevice for tuples is not supported."); @@ -433,7 +433,7 @@ Status PyTpuBuffer::BlockHostUntilReady() { /* static */ StatusOr> PyTpuBuffer::AllocateBuffer( const Shape& shape, std::shared_ptr client, - std::shared_ptr device) { + std::shared_ptr device) { tensorflow::profiler::TraceMe traceme("PyTpuBuffer::AllocateBuffer"); VLOG(1) << "PyTpuBuffer::AllocateBuffer: shape: " << shape.DebugString() << " device: " << device->DebugString(); @@ -465,7 +465,7 @@ StatusOr> PyTpuBuffer::AllocateBuffer( /*static*/ StatusOr> PyTpuBuffer::CreateBuffer( const Shape& non_tuple_shape, absl::optional initializer, - std::shared_ptr client, std::shared_ptr device) { + std::shared_ptr client, std::shared_ptr device) { tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CreateBuffer"); VLOG(1) << "PyTpuBuffer::CreateBuffer: shape: " << non_tuple_shape.DebugString() @@ -493,8 +493,8 @@ StatusOr> PyTpuBuffer::CreateBuffer( std::vector>(), client); } -static std::shared_ptr LookupDevice(const PyTpuClient& client, - int device_id) { +static std::shared_ptr LookupDevice(const PyTpuClient& client, + int device_id) { auto it = client.id_to_device().find(device_id); CHECK(it != client.id_to_device().end()) << "Unknown device id: " << device_id; @@ -516,7 +516,7 @@ PyTpuExecutable::PyTpuExecutable( for (int replica = 0; replica < num_replicas; ++replica) { for (int partition = 0; partition < num_partitions; ++partition) { int device_id = device_assignment_(replica, partition); - std::shared_ptr device = LookupDevice(*client_, device_id); + std::shared_ptr device = LookupDevice(*client_, device_id); if (device->host_id() != client_->host_id()) { VLOG(3) << "Non-local device: " << device_id; continue; @@ -541,7 +541,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper( absl::Span this_core_arguments, int replica, int partition, const RunId& run_id) { const int device_id = device_assignment_(replica, partition); - std::shared_ptr device = LookupDevice(*client_, device_id); + std::shared_ptr device = 
LookupDevice(*client_, device_id); CHECK_EQ(device->host_id(), client_->host_id()); tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute"); VLOG(3) << "Replica " << replica << ", partition " << partition @@ -588,7 +588,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper( static const absl::Duration kWarnExecutionDelay = absl::Seconds(10); // Delay before terminating a stalled execute call. -static const absl::Duration kMaxExecutionDelay = absl::Seconds(120); +static const absl::Duration kMaxExecutionDelay = absl::Minutes(60); Status WaitForExecuteEvent(tpu_driver::Event* event) { absl::optional opt_status; diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index 4c45df181db..c2a424677fd 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -38,7 +38,7 @@ namespace xla { constexpr char kTpuPlatform[] = "tpu"; -class TpuDevice : public Device { +class TpuDevice : public PjRtDevice { public: TpuDevice(int id, int host_id, const std::array& coords, int core_on_chip); @@ -48,8 +48,8 @@ class TpuDevice : public Device { std::string DebugString() const override; - static xla::StatusOr>> GetTpuDevices( - const tpu_driver::SystemInfo& system_info); + static xla::StatusOr>> + GetTpuDevices(const tpu_driver::SystemInfo& system_info); private: const std::array coords_; @@ -66,7 +66,7 @@ class PyTpuClient { explicit PyTpuClient(std::string platform_name, std::unique_ptr driver, - std::vector> devices, + std::vector> devices, int host_id); virtual ~PyTpuClient() = default; @@ -83,11 +83,11 @@ class PyTpuClient { int device_count() const { return devices_.size(); } int local_device_count() const { return local_devices_.size(); } - const std::vector>& devices() { return devices_; } - const std::vector>& local_devices() { + const std::vector>& devices() { return devices_; } + const std::vector>& local_devices() { return local_devices_; } - const std::map>& id_to_device() const { + const std::map>& id_to_device() const { return id_to_device_; } int host_id() const { return host_id_; } @@ -110,11 +110,11 @@ class PyTpuClient { std::unique_ptr driver_; // Includes all devices, including non-local devices on multi-host platforms. - std::vector> devices_; + std::vector> devices_; // Maps Device::id() to the corresponding Device. Includes all devices. - std::map> id_to_device_; + std::map> id_to_device_; // Local devices indexed by local device ordinal. - std::vector> local_devices_; + std::vector> local_devices_; int host_id_; // A thread pool for scheduling core executions in parallel. @@ -128,7 +128,7 @@ struct TpuSharedBuffer final { TpuSharedBuffer(tpu_driver::TpuDriver* driver, std::unique_ptr handle, std::vector> wait_for_use, - std::shared_ptr src_device) + std::shared_ptr src_device) : driver(driver), device(std::move(src_device)), handle(std::move(handle)), @@ -143,7 +143,7 @@ struct TpuSharedBuffer final { } tpu_driver::TpuDriver* const driver; - const std::shared_ptr device; + const std::shared_ptr device; std::unique_ptr handle; std::vector> wait_for_use; @@ -162,12 +162,12 @@ class PyTpuBuffer { static StatusOr> FromLiterals( std::vector leaves_literals, const Shape& tuple_shape, std::shared_ptr leaves_reference, - std::shared_ptr client, std::shared_ptr device); + std::shared_ptr client, std::shared_ptr device); // Supports nested tuple creation. 
static StatusOr> MakeTuple( absl::Span buffers, - std::shared_ptr client, std::shared_ptr device); + std::shared_ptr client, std::shared_ptr device); PyTpuBuffer() = delete; PyTpuBuffer(Shape on_host_shape, @@ -181,7 +181,7 @@ class PyTpuBuffer { PyTpuBuffer& operator=(PyTpuBuffer&&) = delete; const Shape& on_host_shape() const { return on_host_shape_; } - std::shared_ptr device() const { return device_; } + std::shared_ptr device() const { return device_; } const std::string& platform_name() const { return client_->platform_name(); } std::shared_ptr client() const { return client_; } @@ -210,7 +210,7 @@ class PyTpuBuffer { // Copies the buffer to target device `dst_device` and returns a PyTpuBuffer // object holding the context to the target device buffer. StatusOr> CopyToDevice( - std::shared_ptr dst_device); + std::shared_ptr dst_device); // Blocks the host until the buffer's value has been computed and is ready for // immediate use on the device. Useful in particular for timing benchmarks. @@ -220,7 +220,7 @@ class PyTpuBuffer { // tuple, the returned buffer corresponds to the root tuple buffer. static StatusOr> AllocateBuffer( const Shape& shape, std::shared_ptr client, - std::shared_ptr device); + std::shared_ptr device); private: // Initializes a just allocated device buffer. The returned event will be @@ -231,11 +231,11 @@ class PyTpuBuffer { static StatusOr> CreateBuffer( const Shape& non_tuple_shape, absl::optional initializer, - std::shared_ptr client, std::shared_ptr device); + std::shared_ptr client, std::shared_ptr device); const std::shared_ptr client_; const Shape on_host_shape_; - const std::shared_ptr device_; + const std::shared_ptr device_; // If this is a tuple, `device_buffer_` stores the tuple buffer and // `child_buffers_` stores the child buffers; else, `device_buffer_` stores @@ -302,7 +302,7 @@ class PyTpuExecutable { return local_logical_device_ids_; } - const std::vector>& local_devices() const { + const std::vector>& local_devices() const { return local_devices_; } @@ -350,7 +350,7 @@ class PyTpuExecutable { // assigned. // shared_ptrs instead of unique_ptrs to play well with the Python bindings // (see xla.cc). 
- std::vector> local_devices_; + std::vector> local_devices_; xla::Shape result_shape_; }; diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index 9a794b79c5c..5d526b51899 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -40,11 +40,12 @@ PYBIND11_MODULE(tpu_client_extension, m) { .def("host_id", &PyTpuClient::host_id) .def("get_default_device_assignment", [](PyTpuClient* client, int num_replicas, int num_partitions) - -> StatusOr>>> { + -> StatusOr< + std::vector>>> { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, client->GetDefaultDeviceAssignment( num_replicas, num_partitions)); - std::vector>> result; + std::vector>> result; result.resize(num_replicas); for (int r = 0; r < num_replicas; ++r) { result[r].resize(num_partitions); @@ -60,11 +61,11 @@ PYBIND11_MODULE(tpu_client_extension, m) { // TODO(skye): delete after all callers can handle 2D output .def("get_default_device_assignment", [](PyTpuClient* client, int num_replicas) - -> StatusOr>> { + -> StatusOr>> { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, client->GetDefaultDeviceAssignment( num_replicas, /*num_partitions=*/1)); - std::vector> result; + std::vector> result; for (int i = 0; i < num_replicas; ++i) { int device_id = device_assignment(i, 0); auto iter = client->id_to_device().find(device_id); @@ -96,7 +97,8 @@ PYBIND11_MODULE(tpu_client_extension, m) { .def( "buffer_from_pyval", [](std::shared_ptr client, - const pybind11::object& argument, std::shared_ptr device, + const pybind11::object& argument, + std::shared_ptr device, bool force_copy) -> StatusOr> { if (device == nullptr) { TF_RET_CHECK(!client->local_devices().empty()); @@ -145,7 +147,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { py::class_(m, "PyTpuBuffer") .def_property_readonly("client", &PyTpuBuffer::client) .def("copy_to_device", - [](PyTpuBuffer* buffer, std::shared_ptr dst_device) { + [](PyTpuBuffer* buffer, std::shared_ptr dst_device) { CHECK(dst_device != nullptr); GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; @@ -202,7 +204,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { .def_property_readonly("traceback", [](PyTpuExecutable*) { return py::none(); }); - py::class_>(m, "TpuDevice") + py::class_>(m, "TpuDevice") .def_property_readonly("coords", &TpuDevice::coords) .def_property_readonly("core_on_chip", &TpuDevice::core_on_chip) .def("__repr__", [](const TpuDevice& device) { diff --git a/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc new file mode 100644 index 00000000000..ac54df39895 --- /dev/null +++ b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc @@ -0,0 +1,806 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_split.h" +#include "tensorflow/compiler/xla/pjrt/semaphore.h" +#include "tensorflow/compiler/xla/pjrt/worker_thread.h" +#include "tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h" +#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h" +#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" + +namespace tpu_driver { +namespace { + +using xla::Status; +using xla::WorkerThread; + +const char kPodTpuDriverPrefix[] = "grpc+pod://"; + +class PodTpuDriver; + +class PodEvent : public Event { + public: + explicit PodEvent(PodTpuDriver* driver, int64_t operation_id) + : driver_(driver), operation_id_(operation_id) {} + int64_t operation_id() const { return operation_id_; } + + xla::Status Await() override; + + absl::optional AwaitWithTimeout( + absl::Duration duration) override; + + void AddCallback(std::function callback) override; + + private: + PodTpuDriver* driver_; + const int64_t operation_id_; +}; + +class CombinedEvent : public PodEvent { + public: + explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id, + std::vector> events) + : PodEvent(driver, operation_id), events_(events) {} + + xla::Status Await() override { + for (auto& event : events_) { + TF_RETURN_IF_ERROR(event->Await()); + } + return Status::OK(); + } + + absl::optional AwaitWithTimeout( + absl::Duration duration) override { + // TODO(frankchn): This might extend the timeout. + for (auto& event : events_) { + auto status = event->AwaitWithTimeout(duration); + if (status == absl::nullopt) { + return absl::nullopt; + } else { + TF_RETURN_IF_ERROR(status.value()); + } + } + return Status::OK(); + } + + void AddCallback(std::function callback) override { + // TODO(frankchn): This may return before every event is done. 
+ events_[0]->AddCallback(std::move(callback)); + } + + private: + std::vector> events_; +}; + +class PodBufferHandle : public BufferHandle { + public: + explicit PodBufferHandle(PodTpuDriver* driver, int64_t operation_id, + int64_t size_in_bytes, + absl::optional shape, + int64_t core_id) + : driver_(driver), + operation_id_(operation_id), + size_in_bytes_(size_in_bytes), + shape_(shape), + event_(std::make_shared(driver_, operation_id_)), + core_id_(core_id) {} + + std::shared_ptr OnReady() override { return event_; } + int64_t size_in_bytes() override { return size_in_bytes_; } + absl::optional shape() override { return shape_; } + + int64_t operation_id() const { return operation_id_; } + int64_t core_id() const { return core_id_; } + + private: + PodTpuDriver* driver_; + const int64_t operation_id_; + const int64_t size_in_bytes_; + const absl::optional shape_; + std::shared_ptr event_; + const int64_t core_id_; +}; + +class PodCompiledProgramHandle : public CompiledProgramHandle { + public: + explicit PodCompiledProgramHandle(PodTpuDriver* driver, int64_t operation_id) + : driver_(driver), + operation_id_(operation_id), + event_(std::make_shared(driver_, operation_id_)) {} + + std::shared_ptr OnReady() override { return event_; } + + xla::Status program_shape(xla::ProgramShapeProto* program_shape) override; + + int64_t operation_id() const { return operation_id_; } + + private: + PodTpuDriver* driver_; + const int64_t operation_id_; + std::shared_ptr event_; +}; + +class PodLoadedProgramHandle : public LoadedProgramHandle { + public: + explicit PodLoadedProgramHandle(PodTpuDriver* driver, int64_t operation_id, + int64_t core_id) + : driver_(driver), + operation_id_(operation_id), + core_id_(core_id), + event_(std::make_shared(driver_, operation_id_)) {} + + std::shared_ptr OnReady() override { return event_; } + + int64_t operation_id() const { return operation_id_; } + int64_t core_id() const { return core_id_; } + + private: + PodTpuDriver* driver_; + const int64_t operation_id_; + const int64_t core_id_; + std::shared_ptr event_; +}; + +struct EventInFlight { + std::shared_ptr underlying_event; + std::function(void)> create_fn; + + absl::flat_hash_set incomplete_deps; + std::vector> callbacks; +}; + +class PodTpuDriver : public TpuDriver { + public: + explicit PodTpuDriver(const TpuDriverConfig& config, + std::shared_ptr<::grpc::ChannelCredentials> creds) + : config_(config), + creds_(creds), + event_thread_(tensorflow::Env::Default(), "grpc_pod_event_thread") { + std::vector workers = absl::StrSplit( + absl::StripPrefix(config.worker(), kPodTpuDriverPrefix), ','); + for (const auto& worker : workers) { + TpuDriverConfig worker_config(config_); + *(worker_config.mutable_worker()) = absl::StrCat("grpc://", worker); + drivers_.push_back( + CreateGrpcTpuDriver(worker_config, creds_).ConsumeValueOrDie()); + } + + int cumulative_core_id = 0; + absl::flat_hash_set> processed_chips; + + for (int driver_num = 0; driver_num < workers.size(); ++driver_num) { + SystemInfo driver_info; + drivers_[driver_num]->QuerySystemInfo(&driver_info); + + for (const auto& tpu_chip : driver_info.tpu_chip()) { + std::tuple coord{tpu_chip.chip_coord().x(), + tpu_chip.chip_coord().y(), + tpu_chip.chip_coord().z()}; + // We only want to add chips that we have not seen before if we are in a + // TPU pod slice, or we are only seeing local cores (e.g. we are + // connected to individual TPUs or we are in a test environment). 
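[Editorial aside, not part of the patch] The constructor above turns a single pod address into one gRPC driver per worker host. The standalone sketch below uses a made-up two-host address and mirrors only the absl::StrSplit/StripPrefix/StrCat handling in PodTpuDriver's constructor.

#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"

// Hypothetical illustration only.
std::vector<std::string> SplitPodWorkers(const std::string& worker_address) {
  // e.g. worker_address = "grpc+pod://10.0.0.1:8470,10.0.0.2:8470" (made up).
  std::vector<std::string> hosts =
      absl::StrSplit(absl::StripPrefix(worker_address, "grpc+pod://"), ',');
  for (std::string& host : hosts) {
    // Each entry is then dialed as an ordinary gRPC TPU driver.
    host = absl::StrCat("grpc://", host);
  }
  return hosts;  // {"grpc://10.0.0.1:8470", "grpc://10.0.0.2:8470"}
}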
+ if (!processed_chips.contains(coord) || + driver_info.core_count() == driver_info.local_core_size()) { + *(pod_info_.add_tpu_chip()) = tpu_chip; + processed_chips.insert(coord); + } + } + + *(pod_info_.mutable_cpu()) = driver_info.cpu(); + } + + // Process all the unique chips that we have seen. + for (auto& tpu_chip : *pod_info_.mutable_tpu_chip()) { + for (auto& tpu_core : *tpu_chip.mutable_core()) { + int current_core = cumulative_core_id++; + + core_to_driver_.push_back(drivers_[tpu_chip.host_id()].get()); + core_to_driver_id_.push_back(tpu_chip.host_id()); + core_to_driver_core_.push_back(tpu_core.id()); + + tpu_core.set_id(current_core); + tpu_core.set_core_on_host_index(current_core); + *(pod_info_.add_local_core()) = tpu_core; + } + + // We are setting host_id to zero because we want this to look like one + // host with many cores from the perspective of tpu_client.cc. + tpu_chip.set_host_id(0); + } + + pod_info_.set_chip_count(pod_info_.tpu_chip_size()); + pod_info_.set_core_count(pod_info_.local_core_size()); + + // We want this to look like one host with many TPU chips/cores connected. + pod_info_.set_host_count(1); + pod_info_.set_host_id(0); + } + + ~PodTpuDriver() override { + // TODO(frankchn): Unload all handles, and wait for all events to finish. + } + + void QuerySystemInfo(SystemInfo* system_info) override { + *system_info = pod_info_; + } + + xla::Status Reset() override { + for (auto& driver : drivers_) { + TF_RETURN_IF_ERROR(driver->Reset()); + } + return xla::Status::OK(); + } + + std::unique_ptr Allocate( + int32_t core_id, MemoryRegion region, int64_t num_bytes, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + + ScheduleRequest( + operation_id, + [this, core_id, region, num_bytes, operation_id]() { + absl::MutexLock l(&mu_); + underlying_buffers_.insert( + {operation_id, + core_to_driver_[core_id]->Allocate(core_to_driver_core_[core_id], + region, num_bytes, {})}); + return underlying_buffers_[operation_id]->OnReady(); + }, + deps); + + return absl::make_unique(this, operation_id, num_bytes, + absl::nullopt, core_id); + } + + std::unique_ptr Allocate( + int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + + ScheduleRequest( + operation_id, + [this, core_id, region, shape, operation_id]() { + absl::MutexLock l(&mu_); + underlying_buffers_.insert( + {operation_id, + core_to_driver_[core_id]->Allocate(core_to_driver_core_[core_id], + region, shape, {})}); + return underlying_buffers_[operation_id]->OnReady(); + }, + deps); + + return absl::make_unique( + this, operation_id, ComputeBytesFromShape(shape), shape, core_id); + } + + std::unique_ptr AllocateTuple( + int32_t core_id, MemoryRegion region, + absl::Span children, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + + std::vector children_ids; + for (int i = 0; i < children.size(); ++i) { + auto child_op_id = + static_cast(children[i])->operation_id(); + deps.insert(child_op_id); + children_ids.push_back(child_op_id); + } + + ScheduleRequest( + operation_id, + [this, core_id, region, children_ids, operation_id]() { + absl::MutexLock l(&mu_); + + std::vector child_buffers; + child_buffers.reserve(children_ids.size()); + for (int i = 0; i < children_ids.size(); ++i) { + 
child_buffers.push_back(underlying_buffers_[children_ids[i]].get()); + } + + underlying_buffers_.insert( + {operation_id, + core_to_driver_[core_id]->AllocateTuple( + core_to_driver_core_[core_id], region, child_buffers, {})}); + return underlying_buffers_[operation_id]->OnReady(); + }, + deps); + + return absl::make_unique(this, operation_id, 0, + absl::nullopt, core_id); + } + + std::shared_ptr Deallocate( + std::unique_ptr handle, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + deps.insert(static_cast(handle.get())->operation_id()); + + auto op_id = static_cast(handle.get())->operation_id(); + auto core_id = static_cast(handle.get())->core_id(); + + ScheduleRequest( + operation_id, + [this, op_id, core_id]() { + absl::MutexLock l(&mu_); + auto buf_iter = underlying_buffers_.find(op_id); + auto underlying_hn = std::move(buf_iter->second); + underlying_buffers_.erase(buf_iter); + + return core_to_driver_[core_id]->Deallocate(std::move(underlying_hn), + {}); + }, + deps); + + return std::make_shared(this, operation_id); + } + + std::shared_ptr TransferToDevice( + const void* src, BufferHandle* dst, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + deps.insert(static_cast(dst)->operation_id()); + + auto op_id = static_cast(dst)->operation_id(); + auto core_id = static_cast(dst)->core_id(); + + ScheduleRequest( + operation_id, + [this, src, op_id, core_id]() { + absl::MutexLock l(&mu_); + auto buf_iter = underlying_buffers_.find(op_id); + return core_to_driver_[core_id]->TransferToDevice( + src, buf_iter->second.get(), {}); + }, + deps); + + return std::make_shared(this, operation_id); + } + + std::shared_ptr TransferFromDevice( + const BufferHandle* src, void* dst, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + deps.insert(static_cast(src)->operation_id()); + + auto op_id = static_cast(src)->operation_id(); + auto core_id = static_cast(src)->core_id(); + + ScheduleRequest( + operation_id, + [this, dst, op_id, core_id]() { + absl::MutexLock l(&mu_); + auto buf_iter = underlying_buffers_.find(op_id); + return core_to_driver_[core_id]->TransferFromDevice( + buf_iter->second.get(), dst, {}); + }, + deps); + + return std::make_shared(this, operation_id); + } + + std::shared_ptr TransferFromDeviceToDevice( + const BufferHandle* src, BufferHandle* dst, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + deps.insert(static_cast(src)->operation_id()); + deps.insert(static_cast(dst)->operation_id()); + + auto src_op_id = static_cast(src)->operation_id(); + auto dst_op_id = static_cast(dst)->operation_id(); + auto core_id = static_cast(dst)->core_id(); + + ScheduleRequest( + operation_id, + [this, src_op_id, dst_op_id, core_id]() { + absl::MutexLock l(&mu_); + auto src_iter = underlying_buffers_.find(src_op_id); + auto dst_iter = underlying_buffers_.find(dst_op_id); + return core_to_driver_[core_id]->TransferFromDeviceToDevice( + src_iter->second.get(), dst_iter->second.get(), {}); + }, + deps); + + return std::make_shared(this, operation_id); + } + + std::unique_ptr CompileProgram( + const xla::HloProto& source, int32_t num_replicas, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + + ScheduleRequest( + 
operation_id, + [this, operation_id, source, num_replicas]() { + absl::MutexLock l(&mu_); + auto cph_iterator = + underlying_cph_ + .insert( + {operation_id, + std::vector>()}) + .first; + + std::vector> collected_events; + for (int i = 0; i < drivers_.size(); ++i) { + auto current_cph = + drivers_[i]->CompileProgram(source, num_replicas, {}); + cph_iterator->second.push_back(std::move(current_cph)); + collected_events.push_back(cph_iterator->second[i]->OnReady()); + } + return std::make_shared(this, operation_id, + collected_events); + }, + deps); + + return absl::make_unique(this, operation_id); + } + + std::unique_ptr LoadProgram( + int32_t core_id, const CompiledProgramHandle* handle, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + deps.insert( + static_cast(handle)->operation_id()); + auto cph_op_id = + static_cast(handle)->operation_id(); + + ScheduleRequest( + operation_id, + [this, operation_id, cph_op_id, core_id]() { + absl::MutexLock l(&mu_); + auto cph_iter = underlying_cph_.find(cph_op_id); + + underlying_lph_.insert( + {operation_id, + core_to_driver_[core_id]->LoadProgram( + core_to_driver_core_[core_id], + cph_iter->second[core_to_driver_id_[core_id]].get(), {})}); + + return underlying_lph_[operation_id]->OnReady(); + }, + deps); + + return absl::make_unique(this, operation_id, + core_id); + } + + std::shared_ptr UnloadProgram( + std::unique_ptr handle, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + deps.insert( + static_cast(handle.get())->operation_id()); + auto op_id = + static_cast(handle.get())->operation_id(); + auto core_id = + static_cast(handle.get())->core_id(); + + ScheduleRequest( + operation_id, + [this, op_id, core_id]() { + absl::MutexLock l(&mu_); + + auto lph_iter = underlying_lph_.find(op_id); + auto event = core_to_driver_[core_id]->UnloadProgram( + std::move(lph_iter->second), {}); + underlying_lph_.erase(lph_iter); + + return event; + }, + deps); + + return std::make_shared(this, operation_id); + } + + std::shared_ptr ExecuteProgram( + LoadedProgramHandle* program, absl::Span inputs, + absl::Span outputs, + const xla::DeviceAssignmentProto& device_assignment, + absl::Span wait_for) override { + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + deps.insert(static_cast(program)->operation_id()); + + auto op_id = static_cast(program)->operation_id(); + auto core_id = static_cast(program)->core_id(); + + std::vector input_op_ids; + std::vector output_op_ids; + + for (auto* input : inputs) { + auto input_dep = + static_cast(input)->operation_id(); + input_op_ids.push_back(input_dep); + deps.insert(input_dep); + } + for (auto* output : outputs) { + auto output_dep = + static_cast(output)->operation_id(); + output_op_ids.push_back(output_dep); + deps.insert(output_dep); + } + + ScheduleRequest( + operation_id, + [this, core_id, op_id, input_op_ids, output_op_ids, + device_assignment]() { + absl::MutexLock l(&mu_); + + std::vector underlying_inputs; + std::vector underlying_outputs; + + underlying_inputs.reserve(input_op_ids.size()); + for (auto input_op_id : input_op_ids) { + underlying_inputs.push_back(underlying_buffers_[input_op_id].get()); + } + underlying_outputs.reserve(output_op_ids.size()); + for (auto output_op_id : output_op_ids) { + underlying_outputs.push_back( + underlying_buffers_[output_op_id].get()); + } + + LoadedProgramHandle* handle = 
underlying_lph_[op_id].get(); + return core_to_driver_[core_id]->ExecuteProgram( + handle, underlying_inputs, underlying_outputs, device_assignment, + {}); + }, + deps); + + return std::make_shared(this, operation_id); + } + + std::unique_ptr GetLinearizer() override { + return drivers_[0]->GetLinearizer(); + } + + // Helper methods for Event scheduling + + absl::optional WaitForEvent(int64_t event_id, + absl::Duration duration) { + std::shared_ptr underlying_event; + + { + absl::MutexLock l(&event_mu_); + auto event = events_.find(event_id); + + if (event == events_.end()) { + auto event_status = abnormal_event_status_.find(event_id); + if (event_status == abnormal_event_status_.end()) { + return Status::OK(); + } else { + return event_status->second; + } + } + + auto done = [this, event_id]() { + event_mu_.AssertHeld(); + return events_[event_id].underlying_event != nullptr; + }; + + auto status = + event_mu_.AwaitWithTimeout(absl::Condition(&done), duration); + if (!status) { + return absl::nullopt; + } + underlying_event = events_[event_id].underlying_event; + } + + // Wait for the underlying event without holding on to the event_lock_, or + // else incoming events will not be processed. + return underlying_event->AwaitWithTimeout(duration); + } + + void AddCallbackForEvent(int64_t event_id, std::function fn) { + absl::MutexLock l(&event_mu_); + auto event = events_.find(event_id); + + if (event == events_.end()) { + auto event_status = abnormal_event_status_.find(event_id); + if (event_status == abnormal_event_status_.end()) { + fn(Status::OK()); + } else { + fn(event_status->second); + } + } + + if (event->second.underlying_event != nullptr) { + event->second.underlying_event->AddCallback(fn); + } else { + event->second.callbacks.push_back(std::move(fn)); + } + } + + xla::Status GetCompiledProgramShape(int64_t op_id, + xla::ProgramShapeProto* program_shape) { + absl::MutexLock l(&mu_); + + auto done = [this, op_id]() { + mu_.AssertHeld(); + return underlying_cph_.contains(op_id); + }; + mu_.Await(absl::Condition(&done)); + + return underlying_cph_[op_id][0]->program_shape(program_shape); + } + + private: + const TpuDriverConfig& config_; + std::shared_ptr<::grpc::ChannelCredentials> creds_; + + std::vector> drivers_; + std::vector core_to_driver_id_; + std::vector core_to_driver_; + std::vector core_to_driver_core_; + SystemInfo pod_info_; + + absl::Mutex mu_; + absl::Mutex event_mu_; + + absl::flat_hash_map> + underlying_buffers_ ABSL_GUARDED_BY(mu_); + absl::flat_hash_map>> + underlying_cph_ ABSL_GUARDED_BY(mu_); + absl::flat_hash_map> + underlying_lph_ ABSL_GUARDED_BY(mu_); + + absl::btree_map events_ ABSL_GUARDED_BY(event_mu_); + absl::flat_hash_map abnormal_event_status_ + ABSL_GUARDED_BY(event_mu_); + + std::atomic operation_id_counter_{0}; + + WorkerThread event_thread_; + + int64_t GetOperationId() { return operation_id_counter_++; } + + absl::flat_hash_set GetDependencyOperationIds( + absl::Span wait_for) { + absl::flat_hash_set deps; + for (auto* event : wait_for) { + deps.insert(static_cast(event)->operation_id()); + } + return deps; + } + + // EventCompleted is executed on the event_thread_ worker thread. We want + // to propagate the fact that the event is completed to any subsequent events + // that might depend on this event. 
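[Editorial aside, not part of the patch] ScheduleRequest and EventCompleted together maintain a small dependency graph keyed by operation id: a request either runs at once (all of its dependencies have already completed) or is parked until EventCompleted erases its last incomplete dependency. The toy below is a hypothetical, single-threaded analogue with the mutexes, Event objects, and callbacks stripped out, kept only to show that bookkeeping.

#include <cstdint>
#include <functional>
#include <map>
#include <set>
#include <vector>

// Hypothetical, single-threaded analogue of ScheduleRequest/EventCompleted.
class ToyScheduler {
 public:
  void Schedule(int64_t id, std::function<void()> fn,
                const std::set<int64_t>& deps) {
    // Only dependencies that are still pending count as incomplete; ids that
    // have already completed are simply absent from pending_.
    std::set<int64_t> incomplete;
    for (int64_t d : deps) {
      if (pending_.count(d) > 0) incomplete.insert(d);
    }
    if (incomplete.empty()) {
      fn();           // run immediately, like the create_fn fast path
      Completed(id);  // and propagate completion to any waiters
    } else {
      pending_[id] = Pending{std::move(fn), std::move(incomplete)};
    }
  }

 private:
  struct Pending {
    std::function<void()> run;          // deferred work (create_fn analogue)
    std::set<int64_t> incomplete_deps;  // ids this request still waits on
  };

  void Completed(int64_t id) {
    std::vector<int64_t> ready;
    for (auto& entry : pending_) {
      Pending& p = entry.second;
      // erase() returns 1 only if this request was actually waiting on id.
      if (p.incomplete_deps.erase(id) == 1 && p.incomplete_deps.empty()) {
        ready.push_back(entry.first);
      }
    }
    for (int64_t r : ready) {
      std::function<void()> fn = std::move(pending_[r].run);
      pending_.erase(r);
      fn();
      Completed(r);  // completion cascades, as in EventCompleted
    }
  }

  std::map<int64_t, Pending> pending_;
};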
+ void EventCompleted(int64_t event_id, Status status) { + absl::MutexLock l(&event_mu_); + + absl::btree_map::iterator curr_event; + if (!status.ok()) abnormal_event_status_.insert({event_id, status}); + curr_event = events_.find(event_id); + + DCHECK(curr_event->second.callbacks.empty()); + DCHECK(curr_event->second.incomplete_deps.empty()); + + for (auto& event : events_) { + event.second.incomplete_deps.erase(event_id); + // The if statement conditions on both + // - all previous events have completed (incomplete_deps.empty()) + // - the op creating this event has not been called yet + // (event.second.create_fn != nullptr) + // We call the create_fn that creates the event and adds any relevant + // callbacks to the actual event, before setting create_fn to nullptr + // to indicate that it has already been called + if (event.second.incomplete_deps.empty() && + event.second.create_fn != nullptr) { + // We were the last unfilled dependency, all other dependencies are + // filled. We can now fire the create function. + event.second.underlying_event = event.second.create_fn(); + for (auto& fn : event.second.callbacks) { + event.second.underlying_event->AddCallback(std::move(fn)); + } + event.second.callbacks.clear(); + event.second.create_fn = nullptr; + } + } + + // We erase the current event to signal that it has finished. + events_.erase(curr_event); + } + + void ScheduleRequest(int64_t operation_id, + std::function(void)> fn, + const absl::flat_hash_set& deps) { + absl::MutexLock l(&event_mu_); + absl::btree_map::iterator event; + absl::flat_hash_set incomplete_deps; + + event = events_.insert({operation_id, {}}).first; + for (const auto& dep : deps) { + if (events_.count(dep) > 0) incomplete_deps.insert(dep); + } + + if (incomplete_deps.empty()) { + // All dependencies have been fulfilled, we execute the request + // immediately and add a callback to inform our event fulfilled thread + // when it is done. + event->second.create_fn = nullptr; + event->second.underlying_event = fn(); + event->second.underlying_event->AddCallback( + [this, operation_id](Status status) { + event_thread_.Schedule([this, operation_id, status]() { + EventCompleted(operation_id, status); + }); + }); + } else { + // There are some dependencies that are not yet fulfilled. We attach + // the request to the event, and will execute it in the EventFulfilled + // worker thread when all its dependencies are fulfilled. 
+ event->second.create_fn = std::move(fn); + event->second.incomplete_deps = std::move(incomplete_deps); + event->second.callbacks.push_back([this, operation_id](Status status) { + event_thread_.Schedule([this, operation_id, status]() { + EventCompleted(operation_id, status); + }); + }); + } + } +}; + +xla::Status PodEvent::Await() { + return driver_->WaitForEvent(operation_id_, absl::InfiniteDuration()).value(); +} + +absl::optional PodEvent::AwaitWithTimeout( + absl::Duration duration) { + return driver_->WaitForEvent(operation_id_, duration); +} + +void PodEvent::AddCallback(std::function callback) { + driver_->AddCallbackForEvent(operation_id_, std::move(callback)); +} + +xla::StatusOr> CreatePodTpuDriver( + const TpuDriverConfig& config, + std::shared_ptr<::grpc::ChannelCredentials> creds) { + return std::unique_ptr(new PodTpuDriver(config, creds)); +} + +xla::Status PodCompiledProgramHandle::program_shape( + xla::ProgramShapeProto* program_shape) { + return driver_->GetCompiledProgramShape(operation_id(), program_shape); +} + +} // namespace + +REGISTER_TPU_DRIVER(kPodTpuDriverPrefix, + [](const TpuDriverConfig& config) + -> xla::StatusOr> { + return CreatePodTpuDriver( + config, + ::grpc::InsecureChannelCredentials()); // NOLINT + }); + +} // namespace tpu_driver diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 9590c5d57c3..06605660b63 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -44,11 +44,13 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/bfloat16.h" #include "tensorflow/compiler/xla/python/dlpack.h" +#include "tensorflow/compiler/xla/python/jax_jit.h" #include "tensorflow/compiler/xla/python/ops.h" #include "tensorflow/compiler/xla/python/outfeed_receiver_py.h" #include "tensorflow/compiler/xla/python/py_buffer.h" #include "tensorflow/compiler/xla/python/py_executable.h" #include "tensorflow/compiler/xla/python/python_ref_manager.h" +#include "tensorflow/compiler/xla/python/pytree.h" #include "tensorflow/compiler/xla/python/traceback.h" #include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" @@ -169,13 +171,13 @@ class TraceMeWrapper : public tensorflow::profiler::TraceMeWrapper { void BuildProfilerSubmodule(py::module* m) { py::module profiler = m->def_submodule("profiler", "TensorFlow profiler integration"); - py::class_> + py::class_> profiler_server_class(profiler, "ProfilerServer"); profiler.def( "start_server", - [](int port) -> std::unique_ptr { - auto server = absl::make_unique(); + [](int port) -> std::unique_ptr { + auto server = absl::make_unique(); server->StartProfilerServer(port); return server; }, @@ -437,26 +439,26 @@ PYBIND11_MODULE(xla_extension, m) { device_assignment); }); - py::class_>( + py::class_>( m, "Device", "A descriptor of an available device.\n\nSubclasses are used to " "represent specific types of devices, e.g. CPUs, GPUs. 
Subclasses may " "have additional properties specific to that device type.") .def_property_readonly( - "id", &Device::id, + "id", &PjRtDevice::id, "Integer ID of this device.\n\nUnique across all available devices " "of this type, including remote devices on multi-host platforms.") - .def_property_readonly("host_id", &Device::host_id, + .def_property_readonly("host_id", &PjRtDevice::host_id, "Integer ID of this device's host.\n\n" "This is always 0 except on multi-host platforms.") - .def_property_readonly("platform", &Device::platform_name) - .def_property_readonly("device_kind", &Device::device_kind) + .def_property_readonly("platform", &PjRtDevice::platform_name) + .def_property_readonly("device_kind", &PjRtDevice::device_kind) .def_property_readonly( "client", - [](const ClientAndPtr& device) { return device.client; }) - .def("__str__", &Device::DebugString) + [](const ClientAndPtr& device) { return device.client; }) + .def("__str__", &PjRtDevice::DebugString) .def("transfer_to_infeed", - [](const Device& device, const LiteralSlice& literal) { + [](const PjRtDevice& device, const LiteralSlice& literal) { GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, @@ -466,7 +468,8 @@ PYBIND11_MODULE(xla_extension, m) { }) .def( "transfer_from_outfeed", - [](const Device& device, const Shape& shape) -> StatusOr { + [](const PjRtDevice& device, + const Shape& shape) -> StatusOr { GlobalPyRefManager()->CollectGarbage(); std::shared_ptr literal_shared; { @@ -490,12 +493,12 @@ PYBIND11_MODULE(xla_extension, m) { return LiteralToPython(std::move(literal_shared)); }); - py::class_>(m, "CpuDevice") + py::class_>(m, "CpuDevice") .def("__repr__", [](const CpuDevice& device) { return absl::StrFormat("CpuDevice(id=%i)", device.id()); }); - py::class_>(m, "GpuDevice") + py::class_>(m, "GpuDevice") .def("__repr__", [](const GpuDevice& device) { return absl::StrFormat("GpuDevice(id=%i)", device.id()); }); @@ -654,7 +657,7 @@ PYBIND11_MODULE(xla_extension, m) { PyTypeObject* buffer_type = reinterpret_cast(buffer.ptr()); buffer_type->tp_as_buffer = PyBuffer::BufferProtocol(); - py::class_> executable( + py::class_> executable( m, "Executable"); executable.def_property_readonly("client", &PyExecutable::client) .def("local_logical_device_ids", &PyExecutable::local_logical_device_ids) @@ -737,7 +740,7 @@ PYBIND11_MODULE(xla_extension, m) { .def(py::init([](const py::bytes& serialized_hlo_module_proto) -> std::unique_ptr { HloModuleProto proto; - proto.ParseFromString(serialized_hlo_module_proto); + proto.ParseFromString(std::string(serialized_hlo_module_proto)); return absl::make_unique(proto); })) .def("get_hlo_module", &GetHloModule) @@ -897,6 +900,8 @@ PYBIND11_MODULE(xla_extension, m) { BuildOpsSubmodule(&m); BuildProfilerSubmodule(&m); BuildOutfeedReceiverSubmodule(&m); + BuildPytreeSubmodule(m); + BuildJaxjitSubmodule(m); py::class_> diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 38c55c6fe5d..da548ca1f0d 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -113,16 +113,17 @@ def _get_local_backends(): _local_backends = collections.OrderedDict() for name, factory in _local_backend_factories.items(): - logging.vlog(2, "Initializing backend '%s'" % name) + logging.vlog(1, "Initializing backend '%s'" % name) try: backend = factory() - except RuntimeError: + except RuntimeError as err: if name == 'cpu': # We always 
expect CPU to initialize successfully. raise else: # If the backend isn't built into the binary, or if it has no devices, # we expect a RuntimeError. + logging.vlog(1, "Error initializing backend '%s': %s" % (name, err)) continue _local_backends[name] = backend return _local_backends diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 49431b19a69..6874d00445c 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -83,6 +83,7 @@ cc_library( deps = [ ":bfloat16_support", ":hlo", + ":hlo_dataflow_analysis", ":hlo_pass", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -1431,6 +1432,7 @@ cc_library( ":hlo_live_range", ":hlo_ordering", ":hlo_proto_cc", + ":memory_space_assignment_repacking", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1683,6 +1685,7 @@ cc_library( hdrs = ["multi_output_fusion.h"], deps = [ ":hlo", + ":hlo_dataflow_analysis", ":hlo_dce", ":hlo_pass", ":hlo_reachability", @@ -1700,7 +1703,10 @@ cc_library( cc_library( name = "hlo_creation_utils", srcs = ["hlo_creation_utils.cc"], - hdrs = ["hlo_creation_utils.h"], + hdrs = [ + "hlo_creation_utils.h", + "//tensorflow/compiler/xla:literal_util", + ], deps = [ ":hlo", ":hlo_module_config", @@ -1816,6 +1822,21 @@ cc_library( ], ) +cc_library( + name = "comparison_expander", + srcs = ["comparison_expander.cc"], + hdrs = ["comparison_expander.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + ":op_expander_pass", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client/lib:comparators", + ], +) + cc_library( name = "scatter_expander", srcs = ["scatter_expander.cc"], @@ -1824,6 +1845,7 @@ cc_library( ":hlo", ":hlo_creation_utils", ":hlo_pass", + ":op_expander_pass", ":while_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", @@ -2259,6 +2281,7 @@ tf_cc_test( srcs = ["gather_expander_test.cc"], deps = [ ":gather_expander", + ":hlo_query", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_macros_header", @@ -2319,9 +2342,11 @@ cc_library( ":call_inliner", ":hlo", ":hlo_casting_utils", + ":hlo_cse", ":hlo_dce", ":hlo_pass", ":hlo_pass_pipeline", + ":hlo_verifier", ":tuple_simplifier", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -2665,6 +2690,7 @@ cc_library( ":hlo_casting_utils", ":hlo_dce", ":hlo_pass", + ":shape_inference", "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -2688,7 +2714,6 @@ xla_test( ":dynamic_padder", ":hlo", ":hlo_dce", - ":hlo_get_dimension_size_rewriter", ":hlo_matchers", ":hlo_parser", "//tensorflow/compiler/xla:debug_options_flags", @@ -3407,6 +3432,35 @@ cc_library( ], ) +cc_library( + name = "memory_space_assignment_repacking", + hdrs = ["memory_space_assignment_repacking.h"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + ], +) + +cc_library( + name = "memory_space_assignment_best_fit_repacker", + srcs = ["memory_space_assignment_best_fit_repacker.cc"], + hdrs = ["memory_space_assignment_best_fit_repacker.h"], + deps = [ + ":heap_simulator", + ":memory_space_assignment_repacking", + ], +) + +tf_cc_test( + name = "memory_space_assignment_best_fit_repacker_test", + srcs = 
["memory_space_assignment_best_fit_repacker_test.cc"], + deps = [ + ":memory_space_assignment_best_fit_repacker", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "memory_space_assignment", srcs = ["memory_space_assignment.cc"], @@ -3414,6 +3468,7 @@ cc_library( deps = [ ":heap_simulator", ":hlo_cost_analysis", + ":memory_space_assignment_repacking", ":memory_space_assignment_utils", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/core/lib/math:math_util", @@ -3968,42 +4023,6 @@ tf_cc_test( ], ) -cc_library( - name = "hlo_get_dimension_size_rewriter", - srcs = ["hlo_get_dimension_size_rewriter.cc"], - hdrs = ["hlo_get_dimension_size_rewriter.h"], - deps = [ - ":dynamic_dimension_inference", - ":hlo", - ":hlo_pass", - ":shape_inference", - "//tensorflow/compiler/xla:literal_util", - "@com_google_absl//absl/algorithm:container", - ], -) - -tf_cc_test( - name = "hlo_get_dimension_size_rewriter_test", - srcs = ["hlo_get_dimension_size_rewriter_test.cc"], - deps = [ - ":hlo", - ":hlo_get_dimension_size_rewriter", - ":hlo_matchers", - ":hlo_parser", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - cc_library( name = "maybe_owning_device_memory", srcs = [ diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 0b588048e4a..4e7bd85e557 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -665,7 +665,7 @@ Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction( HloInstruction* inst; HloInstruction* user; int64 index; - std::tie (inst, user, index) = operands.back(); + std::tie(inst, user, index) = operands.back(); operands.pop_back(); // Skip the op types that are not commutative with multiply. @@ -913,7 +913,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) && Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) && (ShapeUtil::ElementIsIntegral(add->shape()) || - IsAllFpConstantPowerOf2(c))) { + options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) { return ReplaceWithNewInstruction( add, HloInstruction::CreateBinary( add->shape(), HloOpcode::kMultiply, @@ -1236,6 +1236,10 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( return Status::OK(); } + if (options_.is_layout_sensitive()) { + return Status::OK(); + } + // Check if we can merge "adjacent" slice operands which take slices from the // same other op. For simplicity we only merge unstrided slices. int64 concatenate_dimension = concatenate->concatenate_dimension(); @@ -1296,7 +1300,15 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( auto replacement = computation_->AddInstruction(concatenate->CloneWithNewOperands( concatenate->shape(), new_operands)); - ReplaceInstructionIfSameShape(concatenate, replacement); + + // Recurse to handle multiple disjoint sequence of inputs. The + // logic above merge only 1 sequential series of + // inputs. 
Otherwise, it can lead to the FixPass optimization + // hitting its threshold. + if (ReplaceInstructionIfSameShape(concatenate, replacement)) { + return HandleConcatenate(replacement); + } + return Status::OK(); } } @@ -1335,6 +1347,23 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( operands[pad_value_operand]->mutable_operand(0), padding_config)); return ReplaceInstruction(concatenate, pad); } + + if (absl::c_count(operands, operands[0]) == operands.size() && + operands[0]->shape().dimensions(concatenate_dimension) == 1) { + Shape new_shape = operands[0]->shape(); + absl::InlinedVector broadcast_dims; + for (int64 i = 0; i < new_shape.rank(); ++i) { + if (i == concatenate_dimension) { + continue; + } + broadcast_dims.push_back(i); + } + new_shape.DeleteDimension(concatenate_dimension); + return ReplaceInstruction( + concatenate, + MakeBroadcastHlo(MakeReshapeHlo(new_shape, operands[0]).ValueOrDie(), + broadcast_dims, concatenate->shape())); + } return Status::OK(); } @@ -2479,6 +2508,20 @@ Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) { if (ShapeUtil::IsZeroElementArray(operand_shape)) { return ReplaceInstruction(gather, MakeScalarLike(gather, 0)); } + + // Gathering from a scalar operand is simply a broadcast of that scalar + if (ShapeUtil::IsEffectiveScalar(operand_shape)) { + HloInstruction* new_operand = gather->mutable_operand(0); + if (operand_shape.rank()) { + TF_ASSIGN_OR_RETURN(new_operand, + MakeReshapeHlo(ShapeUtil::MakeScalarShape( + operand_shape.element_type()), + new_operand)); + } + HloInstruction* new_gather = + MakeBroadcastHlo(new_operand, {}, gather->shape()); + return ReplaceInstruction(gather, new_gather); + } // If the operand of a gather is very small, it is easier to fuse a // sequence of selects. 
const Shape& index_shape = gather->operand(1)->shape(); @@ -2667,6 +2710,17 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { return Status::OK(); } + { + HloInstruction* abs_operand; + if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) && + !ShapeUtil::ElementIsComplex(abs_operand->shape())) { + TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand)); + TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand)); + changed_ = true; + return Status::OK(); + } + } + { HloInstruction *convert_operand, *operand; // Mul(Convert(Pred), operand) => select(pred, operand, 0) @@ -2691,7 +2745,7 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { // Mul(Mul(x, constant1), Mul(y, constant2)) => Mul(Mul(x, y), // constant1*constant2) if (Match(multiply, - m::Multiply( + m::MultiplyAnyOrder( m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)), m::MultiplyAnyOrder(m::NonConstant(&b), m::Constant(&c2))))) { TF_ASSIGN_OR_RETURN(auto* product_of_constants, @@ -2713,6 +2767,29 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { } } + { + HloInstruction *a, *c1, *c2; + // Mul(Mul(a, constant1), constant2) => Mul(a, constant1*constant2) + if (Match(multiply, + m::MultiplyAnyOrder( + m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)), + m::Constant(&c2)))) { + TF_ASSIGN_OR_RETURN(auto* product_of_constants, + MakeBinaryHlo(HloOpcode::kMultiply, c1, c2)); + if (ShapeUtil::IsScalar(product_of_constants->shape()) && + !ShapeUtil::IsScalar(multiply->shape())) { + product_of_constants = + computation_->AddInstruction(HloInstruction::CreateBroadcast( + multiply->shape(), product_of_constants, {})); + } + + return ReplaceWithNewInstruction( + multiply, + HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply, + a, product_of_constants)); + } + } + { HloInstruction *a, *b, *constant, *op; // Mul(Mul(a, constant1), Broadcast(b)) => @@ -3245,6 +3322,9 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { // padding with a pad with non-negative padding followed by a slice. bool all_zero = true; bool has_negative = false; + // Used to possibly split off the unchanged padding dimensions. + std::vector padding_dimensions; + int64 dimension_index = 0; for (auto& padding_dimension : pad->padding_config().dimensions()) { if (padding_dimension.edge_padding_low() < 0 || padding_dimension.edge_padding_high() < 0) { @@ -3253,12 +3333,93 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { if (padding_dimension.edge_padding_low() != 0 || padding_dimension.edge_padding_high() != 0) { all_zero = false; + padding_dimensions.push_back(dimension_index); + } else if (padding_dimension.interior_padding()) { + padding_dimensions.push_back(dimension_index); } + dimension_index++; } if (all_zero) { - ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0)); - return Status::OK(); + if (ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0))) { + return Status::OK(); + } + } + + // The context of this optimization can be found at b/163617402 + // It tries to capture the case of pad(broadcast(x)), where + // x->shape().dimensions(), or broadcast(x)->dimensions(), is + // a subset of the padded dimensions in pad->config(), + // and the padded dimensions in pad->config() is in turn a strict + // subset of broadcast->shape().dimensions(). 
The combined op can be + // rewritten to broadcast2(pad(broadcast1(x))), where broadcast1 extends + // x with dimensions that need to be padded, and broadcast2 extends + // the result of padding to full dimensions. + // TODO(qyi): for future extensions: The condition for broadcast(x) + // ->dimensions() to be a subset of padded dimensions in pad->config() + // does not have to be strictly required, but it makes the calculation + // for optimization easier, so it is required by the current implementation. + // Only the second condition between the padded dimensions and the + // dimensions of the final shape have to be enforced for the optimization + // to make sense. If needed to remove the first constraint, the shape + // calculations across the implementation need to be re-adjusted. + auto pad_dims = padding_dimensions.size(); + if (pad_dims < dimension_index && + pad->operand(0)->opcode() == HloOpcode::kBroadcast && + pad->operand(0)->user_count() == 1 && + pad->operand(0)->operand(0)->shape().rank() <= pad_dims) { + // Check broadcast operand dimensions is a subset of pading_dimensions. + // If not, skip the optimization. + bool opt_is_valid = true; + std::vector broadcast_dimensions; + HloBroadcastInstruction* broadcast = + static_cast(pad->mutable_operand(0)); + for (auto broadcast_index : broadcast->dimensions()) { + bool found = false; + for (int i = 0; i < pad_dims; ++i) { + if (broadcast_index == padding_dimensions[i]) { + broadcast_dimensions.push_back(i); + found = true; + break; + } + } + if (!found) { + opt_is_valid = false; + break; + } + } + if (opt_is_valid) { + auto pad_shape = pad->shape(); + auto broadcast_shape = broadcast->shape(); + auto pad_shape1 = pad_shape; + auto broadcast_shape1 = broadcast_shape; + PaddingConfig pad_config; + for (int i = padding_dimensions.size() - 1; i >= 0; --i) { + int64 j = padding_dimensions[i]; + while (--dimension_index > j) { + broadcast_shape1.DeleteDimension(dimension_index); + pad_shape1.DeleteDimension(dimension_index); + } + } + while (--dimension_index >= 0) { + broadcast_shape1.DeleteDimension(dimension_index); + pad_shape1.DeleteDimension(dimension_index); + } + for (auto dimension_to_pad : padding_dimensions) { + auto dimension = pad_config.add_dimensions(); + *dimension = pad->padding_config().dimensions(dimension_to_pad); + } + *broadcast->mutable_shape() = broadcast_shape1; + *broadcast->mutable_dimensions() = broadcast_dimensions; + simplifier_->UpdateLayout(broadcast->mutable_shape()); + auto pad2 = + computation_->AddInstruction(pad->CloneWithNewShape(pad_shape1)); + *pad2->mutable_padding_config() = pad_config; + simplifier_->UpdateLayout(pad2->mutable_shape()); + auto broadcast2 = computation_->AddInstruction( + HloInstruction::CreateBroadcast(pad_shape, pad2, padding_dimensions)); + return ReplaceInstruction(pad, broadcast2); + } } if (has_negative) { @@ -3293,7 +3454,8 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { pad->shape(), nonzero_pad->mutable_shape())); simplifier_->UpdateLayout(nonzero_pad->mutable_shape()); - // Second, construct the slice instruction to perform the negative padding. + // Second, construct the slice instruction to perform the negative + // padding. 
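[Editorial aside, not part of the patch] A concrete instance of the pad(broadcast) reordering above appears in the BroadcastAndPadReorder test near the end of this change: a pred[] constant is broadcast to pred[32,1,768] and then padded to pred[4096,1,768] with padding 0_4064x0_0x0_0. Only dimension 0 is padded, so the rewrite emits broadcast1 (pred[] to pred[32], covering just the padded dimensions), pads that to pred[4096] with padding 0_4064, and finally broadcasts back to pred[4096,1,768] along dimension {0}; the pad now operates on a rank-1 shape instead of the full rank-3 shape.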
std::vector start_indices; + std::vector end_indices; + std::vector strides; @@ -4140,8 +4302,8 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( return ReplaceWithNewInstruction(dynamic_slice, std::move(new_broadcast)); } - // Convert a dynamic slice into a slice if all offsets are constant and the - // operand is not constant. If ev + // Convert a dynamic slice into a slice if all offsets are constant and the + // operand is not constant. if (operand->opcode() != HloOpcode::kConstant && absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1, dynamic_slice->operands().end()), @@ -5109,10 +5271,10 @@ StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( if (!reverse_dimensions.empty()) { TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions)); } - TF_ASSIGN_OR_RETURN( - HloInstruction * new_convolution, - MakeConvolveHlo(kernel, input, /*feature_group_count=*/1, swapped_window, - swapped_dnums, precision_config)); + TF_ASSIGN_OR_RETURN(HloInstruction * new_convolution, + MakeConvolveHlo(kernel, input, /*feature_group_count=*/1, + /*batch_group_count=*/1, swapped_window, + swapped_dnums, precision_config)); convolution->SetupDerivedInstruction(new_convolution); TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution)); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 9f2a3404116..cabecec4eb8 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -97,6 +97,14 @@ class AlgebraicSimplifierOptions { return enable_scalar_multiply_reduction_; } + // Allow the algebraic simplifier to treat floating point values like real + // numbers. + void set_enable_floats_are_real(bool enable_floats_are_real) { + enable_floats_are_real_ = enable_floats_are_real; + } + + bool enable_floats_are_real() const { return enable_floats_are_real_; } + // If enable_window_reduce_replacement is true, the kReduceWindow instruction // can be optimized by replacement with simpler operations.
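[Editorial aside, not part of the patch] The new option is consumed in HandleAdd above, where it lifts the power-of-two restriction when factoring A*C + B*C into (A+B)*C for floating-point types. A minimal sketch of opting in, assuming an existing std::unique_ptr<HloModule> named module:

// Hypothetical illustration only.
AlgebraicSimplifierOptions options;
options.set_enable_floats_are_real(true);  // treat fp values as real numbers
AlgebraicSimplifier simplifier(options);
bool changed = simplifier.Run(module.get()).ValueOrDie();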
void set_enable_window_reduce_to_reduce_replacement( @@ -158,6 +166,7 @@ class AlgebraicSimplifierOptions { bool enable_conv_simplification_{true}; bool enable_conv_operand_swap_{true}; bool enable_scalar_multiply_reduction_{false}; + bool enable_floats_are_real_{false}; bool enable_window_reduce_to_reduce_replacement_{true}; bool enable_reduce_of_reshape_{true}; bool replace_transpose_with_bitcast_{true}; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 90ca44714f7..c4f3ea4087b 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -117,6 +117,22 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAddition) { m::ConstantScalar(0.125)))); } +// (Abs(A)) * (Abs(A)) => (A*A) +TEST_F(AlgebraicSimplifierTest, SquareOfAbs) { + const char* kModuleStr = R"( + HloModule m + test { + p = f32[] parameter(0) + a = f32[] abs(p) + ROOT z = f32[] multiply(a, a) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0)))); +} + // (A*C1) * (B*C2) => (A*B)*(C1*C2) TEST_F(AlgebraicSimplifierTest, MultiplyChain) { const char* kModuleStr = R"( @@ -140,6 +156,26 @@ TEST_F(AlgebraicSimplifierTest, MultiplyChain) { m::MultiplyAnyOrder(m::ConstantScalar(2), m::ConstantScalar(4))))); } +// (a*C1)*C2 => a*(C1*C2) +TEST_F(AlgebraicSimplifierTest, MultiplyChain2) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + a = f32[] constant(2) + b = f32[] constant(4) + c = f32[] multiply(p0, a) + ROOT y = f32[] multiply(c, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::Parameter(0), m::MultiplyAnyOrder(m::ConstantScalar(2), + m::ConstantScalar(4))))); +} + // MUL(MUL(X, BROADCAST(constant)), BROADCAST(Y)) ==> // MUL(X, BROADCAST(MUL(Y, BROADCAST(constant)))) TEST_F(AlgebraicSimplifierTest, MultiplyBroadcastReassoc) { @@ -2299,7 +2335,7 @@ TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) { auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {100, 99}); - Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 80}); + Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 90}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); @@ -2346,10 +2382,15 @@ TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) { HloInstruction* slice7 = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 79}, /*limit_indices=*/{100, 89}, /*strides=*/{1, 1})); + // Can merge 'slice7' and 'slice8'. 
+ HloInstruction* slice8 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 89}, + /*limit_indices=*/{100, 99}, /*strides=*/{1, 1})); builder.AddInstruction(HloInstruction::CreateConcatenate( concat_shape, - {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7}, 1)); + {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7, slice8}, + 1)); auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); @@ -2364,6 +2405,12 @@ TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) { ShapeUtil::Equal(computation->root_instruction()->operand(3)->shape(), ShapeUtil::MakeShape(F32, {50, 30}))); EXPECT_EQ(computation->root_instruction()->operand(3)->slice_starts(1), 40); + + // The operand 6 should be merge of 'slice7' and 'slice8', so its + // shape should have dimensions {50, 20} + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->operand(5)->shape(), + ShapeUtil::MakeShape(F32, {50, 20}))); } // Test that a simplification which changes layouts is not performed if layout @@ -4823,6 +4870,25 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { EXPECT_EQ(root->slice_limits(0), 2); } +TEST_F(AlgebraicSimplifierTest, ConcatToBroadcast) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + p = f32[2,1,4] parameter(0) + ROOT concat = f32[2,6,4] concatenate(p,p,p,p,p,p), dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0))))); +} + TEST_F(AlgebraicSimplifierTest, NegateNegate) { const char* hlo_string = R"( HloModule module @@ -5608,6 +5674,30 @@ INSTANTIATE_TEST_SUITE_P( DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest, ::testing::ValuesIn(DotOfGatherPositiveNegativeTests())); +TEST_F(AlgebraicSimplifierTest, GatherOfScalarToBroadcast) { + const char* hlo_string = R"( + HloModule repeat + + ENTRY main { + o = f32[1,1] parameter(0) + i = s32[100,2] parameter(1) + ROOT g = f32[100] gather(o, i), collapsed_slice_dims={0,1}, + start_index_map={0,1}, + index_vector_dim=1, + offset_dims={}, + slice_sizes={1,1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0))))); +} + TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) { const char* hlo_string = R"( HloModule module @@ -6892,5 +6982,57 @@ TEST_F(AlgebraicSimplifierTest, UnaryVariadicReduce) { GmockMatch(m::Add(m::Parameter(0), m::Parameter(1)))); } +TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorder) { + const char* kModuleStr = R"( + HloModule m + test { + c1 = pred[] constant(true) + b2 = pred[32,1,768]{2,1,0} broadcast(pred[] c1), dimensions={} + c3 = pred[] constant(false) + ROOT p4 = pred[4096,1,768]{2,1,0} pad(pred[32,1,768]{2,1,0} b2, pred[] c3), padding=0_4064x0_0x0_0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + 
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast( + m::Pad(m::Broadcast(m::Constant()), m::Constant())))); +} + +TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithUse) { + const char* kModuleStr = R"( + HloModule m + test { + c1 = pred[] constant(true) + b2 = pred[1,768,32]{2,1,0} broadcast(pred[] c1), dimensions={} + c3 = pred[] constant(false) + p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064 + ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Broadcast( + m::Pad(m::Broadcast(m::Constant()), m::Constant()))))); +} + +TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithNonScalar) { + const char* kModuleStr = R"( + HloModule m + test { + c1 = pred[32] parameter(0) + b2 = pred[1,768,32]{2,1,0} broadcast(pred[32] c1), dimensions={2} + c3 = pred[] constant(false) + p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064 + ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Broadcast( + m::Pad(m::Broadcast(m::Parameter()), m::Constant()))))); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/all_reduce_combiner.cc b/tensorflow/compiler/xla/service/all_reduce_combiner.cc index 9d8f03c92ca..5fb4935a4b1 100644 --- a/tensorflow/compiler/xla/service/all_reduce_combiner.cc +++ b/tensorflow/compiler/xla/service/all_reduce_combiner.cc @@ -268,6 +268,11 @@ StatusOr AllReduceCombiner::Run(HloModule* module) { VLOG(1) << "Running AllReduceCombiner with threshold of " << combine_threshold_in_bytes_ << " bytes"; + if (combine_threshold_in_bytes_ <= 0 || combine_threshold_count_ <= 0) { + VLOG(1) << "Skip AllReduceCombiner because the threshold is zero"; + return false; + } + if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) { VLOG(1) << "Skip AllReduceCombiner because the module contains all-reduce " "with constrained layouts"; diff --git a/tensorflow/compiler/xla/service/all_reduce_simplifier.cc b/tensorflow/compiler/xla/service/all_reduce_simplifier.cc index 541006f04d5..18a0fdc1a70 100644 --- a/tensorflow/compiler/xla/service/all_reduce_simplifier.cc +++ b/tensorflow/compiler/xla/service/all_reduce_simplifier.cc @@ -31,27 +31,7 @@ StatusOr AllReduceSimplifier::Run(HloModule* module) { TF_ASSIGN_OR_RETURN( auto replication, HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/false)); - std::vector all_reduces_to_replace; - for (auto computation : module->computations()) { - for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { - if (!inst->shape().IsArray()) { - // We currently do not change tuple-shaped all-reduce. - // Until XLA will support Token fed AllReduce(), the PyTorch client code - // uses a fake data token (constant) which relies on this pass to not - // optimize out (being fed within a tuple input). 
- continue; - } - if (inst->IsCrossReplicaAllReduce() && - replication->HloInstructionIsReplicatedAt(inst->operand(0), {})) { - all_reduces_to_replace.push_back(inst); - } - } - } - - bool changed = false; - if (all_reduces_to_replace.empty()) { - return changed; - } + std::vector> all_reduces_to_replace; // Returns the size of a replica group if all groups have the same size, or -1 // if they have different sizes. @@ -71,7 +51,40 @@ StatusOr AllReduceSimplifier::Run(HloModule* module) { return replica_group_size; }; - for (auto all_reduce : all_reduces_to_replace) { + for (auto computation : module->computations()) { + for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { + if (!inst->shape().IsArray()) { + // We currently do not change tuple-shaped all-reduce. + // Until XLA will support Token fed AllReduce(), the PyTorch client code + // uses a fake data token (constant) which relies on this pass to not + // optimize out (being fed within a tuple input). + continue; + } + if (!inst->IsCrossReplicaAllReduce()) { + continue; + } + int64 group_size = get_replica_group_size(inst); + if (group_size == -1) { + continue; + } + if (replication->HloInstructionIsReplicatedAt(inst->operand(0), {}) || + group_size == 1) { + all_reduces_to_replace.push_back({inst, group_size}); + } + } + } + + bool changed = false; + + for (auto all_reduce_and_group_size : all_reduces_to_replace) { + auto all_reduce = all_reduce_and_group_size.first; + const int64 replica_group_size = all_reduce_and_group_size.second; + if (replica_group_size == 1) { + TF_RETURN_IF_ERROR(all_reduce->parent()->ReplaceInstruction( + all_reduce, all_reduce->mutable_operand(0))); + changed = true; + continue; + } if (all_reduce->to_apply()->instruction_count() != 3 || all_reduce->to_apply()->num_parameters() != 2) { continue; @@ -79,10 +92,6 @@ StatusOr AllReduceSimplifier::Run(HloModule* module) { HloInstruction* replacement; switch (all_reduce->to_apply()->root_instruction()->opcode()) { case HloOpcode::kAdd: { - int64 replica_group_size = get_replica_group_size(all_reduce); - if (replica_group_size == -1) { - continue; - } // Create the multiplier: // broadcast(convert_to_matching_type(s32 group size)) auto multiplier = diff --git a/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc b/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc index 4914836b34a..1e938594cc3 100644 --- a/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc @@ -167,5 +167,30 @@ test { m::Parameter(0), m::AllReduce(m::Parameter(1))))); } +TEST_F(AllReduceSimplifierTest, TrivialSubgroupAllReduce) { + const char* kModuleStr = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + + +test { + p0 = f32[8,16] parameter(0), parameter_replication={false} + ROOT all-reduce = f32[8,16] all-reduce(p0), + replica_groups={{0},{1},{2},{3},{4},{5},{6},{7}}, + to_apply=sum +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + kModuleStr, /*replica_count=*/8)); + AllReduceSimplifier simplifier(/*replica_count=*/8); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(0))); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index 23d2a9225a8..73210e6b3dc 100644 --- 
a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -159,19 +160,20 @@ Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { // Do not fold BF16 conversions for instructions related to tuples, entry and - // exit of a computation, fusion, convert, side-effecting instructions and - // control flow. - if (hlo->opcode() == HloOpcode::kTuple || // - hlo->opcode() == HloOpcode::kGetTupleElement || // - hlo->opcode() == HloOpcode::kConstant || // - hlo->opcode() == HloOpcode::kParameter || // - hlo->opcode() == HloOpcode::kFusion || // - hlo->opcode() == HloOpcode::kBitcastConvert || // - hlo->opcode() == HloOpcode::kConvert || // - hlo->opcode() == HloOpcode::kCall || // - hlo->opcode() == HloOpcode::kCustomCall || // - hlo->opcode() == HloOpcode::kWhile || // - hlo->opcode() == HloOpcode::kConditional || // + // exit of a computation, fusion, convert, side-effecting instructions, + // in-place operations and control flow. + if (hlo->opcode() == HloOpcode::kTuple || // + hlo->opcode() == HloOpcode::kGetTupleElement || // + hlo->opcode() == HloOpcode::kConstant || // + hlo->opcode() == HloOpcode::kParameter || // + hlo->opcode() == HloOpcode::kFusion || // + hlo->opcode() == HloOpcode::kBitcastConvert || // + hlo->opcode() == HloOpcode::kConvert || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kWhile || // + hlo->opcode() == HloOpcode::kConditional || // + HloDataflowAnalysis::IsInPlaceOperation(hlo->opcode()) || // hlo->HasSideEffectNoRecurse()) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index a0fe0eaa1d9..f9e19493a86 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -598,6 +598,31 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( type = F32; break; } + // In order to find aliases due to in-place operations, use + // GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here, + // but this code works with HloModules that aren't ready yet to use + // HloAliasAnalysis (e.g., their computation graphs may not have been + // flattened yet). + for (const auto& operand_and_output_index : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) { + if (operand_and_output_index.second == index) { + const HloUse& operand = operand_and_output_index.first; + for (const auto* value : + dataflow_ + ->GetValueSet(hlo->operand(operand.operand_number), + operand.operand_index) + .values()) { + auto value_type = ValueTypeAfterChange(value); + if (value_type == BF16) { + continue; + } + CHECK_EQ(value_type, F32); + type = F32; + break; + } + } + } + // It's possible that a user has been changed from BF16 to F32 // during this final adjustment pass, so we need to check // AllUsersConsumeBF16() again. 
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 02d79025f1b..9a898833373 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -1156,4 +1156,30 @@ ENTRY entry { EXPECT_FALSE(PropagatePrecision(module.get())); } +TEST_F(BFloat16PropagationTest, DynamicUpdateSlice) { + // This test is crafted so that the DUS has an f32 input (due to parameter) + // and bf16 output (due to dot). But we should enforce DUS operand 0 and + // output to get the same precision since it's an in-place operation. + const string module_str = R"( +HloModule Module + +ENTRY main { + param = f32[128,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + dynamic-update-slice = f32[128,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3) + ROOT dot = f32[128,128] dot(dynamic-update-slice, dynamic-update-slice), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + EXPECT_FALSE(PropagatePrecision(module.get())); + + HloInstruction* dus = module->entry_computation()->GetInstructionWithName( + "dynamic-update-slice"); + EXPECT_FALSE(OutputsBF16(dus)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 6cd58b86f0c..db34f054f35 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -1007,102 +1007,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, return true; } // namespace xla -Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) { - // Try allocate same buffer for dynamic update slice's operand and output. - - // If memory_space_assignment is run and there is information about a color in - // preset assignments, don't merge those buffers. We expect - // memory_space_assignment to have merged these buffers. If - // memory_space_assignment didn't merge these buffers and have assigned - // different offsets to the operand and the output buffer, merging the buffers - // can cause memory corruption if memory_space_assignment assigned a different - // buffer at the same offset. - absl::flat_hash_set excluded_colors; - if (preset_assignments_) { - for (const auto& color_and_info : - preset_assignments_->assignment_informations()) { - excluded_colors.insert(color_and_info.first); - } - } - - // TODO(yunxing): Moving this logic to alias analysis and add must-alias rule - // to operations that can be done in place. - for (HloComputation* computation : assignment->module().computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (!(instruction->opcode() == HloOpcode::kDynamicUpdateSlice || - (instruction->opcode() == HloOpcode::kFusion && - (instruction->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice)))) { - continue; - } - if (instruction->parent()->IsFusionComputation()) { - continue; - } - if (instruction->operand_count() == 0) { - continue; - } - - // The operand can't share the same buffer with the user based on dataflow - // analysis. 
- if (!assignment->dataflow_analysis().CanShareOperandBufferWithUser( - instruction->mutable_operand(0), {}, instruction, {})) { - continue; - } - HloBuffer& instruction_buffer = - assignment->alias_analysis().GetUniqueBufferAt(instruction, {}); - - HloBuffer& operand_buffer = - assignment->alias_analysis().GetUniqueBufferAt( - instruction->operand(0), {}); - - // The instruction or operand color is excluded because it was assigned by - // memory_space_assignment. - if (excluded_colors.contains(instruction_buffer.color()) || - excluded_colors.contains(operand_buffer.color())) { - continue; - } - - // Already have the same buffer. No need to merge those. - if (instruction_buffer.id() == operand_buffer.id()) { - continue; - } - - // Do not perform in-place dynamic update slice if the operand buffer is - // read-only. - if (HloBufferIsReadOnly(operand_buffer)) { - continue; - } - - bool interfere = false; - - for (const HloValue* instruction_value : instruction_buffer.values()) { - for (const HloValue* operand_value : operand_buffer.values()) { - if (assignment->hlo_ordering().MayInterfere( - *instruction_value, *operand_value, - assignment->dataflow_analysis())) { - interfere = true; - break; - } - } - } - if (interfere) { - continue; - } - if (assignment->alias_analysis().BufferLivesOut(instruction_buffer)) { - continue; - } - if (instruction_buffer.color() != operand_buffer.color()) { - continue; - } - VLOG(3) << "Merging inplace " << instruction_buffer << " and " - << operand_buffer; - assignment->alias_analysis().MergeBuffers(instruction_buffer, - operand_buffer); - } - } - return Status::OK(); -} - Status BufferAssigner::AssignSingleHloBuffer( const HloBuffer* hlo_buffer, bool is_thread_local, absl::flat_hash_map>>(); - algorithms->push_back(absl::make_unique( - alignment, GlobalDecreasingSizeBestFitHeap::kSpatial)); - algorithms->push_back(absl::make_unique( - alignment, GlobalDecreasingSizeBestFitHeap::kTemporal)); - return absl::make_unique(std::move(algorithms)); + auto algorithms = absl::make_unique< + std::vector>>>(); + algorithms->push_back( + absl::make_unique>( + alignment, GlobalDecreasingSizeBestFitHeap::kSpatial)); + algorithms->push_back( + absl::make_unique>( + alignment, GlobalDecreasingSizeBestFitHeap::kTemporal)); + return absl::make_unique>( + std::move(algorithms)); }; if (run_whole_module_heap_simulation) { @@ -1461,7 +1368,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &single_colored_set.second; TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, + HeapSimulator::Result result, HeapSimulator::Run( get_heap_algorithm(alignment), assignment->module(), schedule, assignment->alias_analysis(), assignment->buffer_size_, options)); @@ -1487,7 +1394,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( HeapSimulator::Options options; options.buffers_to_assign = &single_colored_set.second; TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, + HeapSimulator::Result result, HeapSimulator::Run(get_heap_algorithm(alignment), *computation, *instruction_sequence, assignment->alias_analysis(), @@ -1582,7 +1489,7 @@ std::vector ComputePeakMemoryLogicalBuffers( } // namespace void BufferAssigner::AssignBuffersFromHeapSimulator( - const HeapSimulator::Result& result, BufferAssignment* assignment, + const HeapSimulator::Result& result, BufferAssignment* assignment, BufferValue::Color color) { if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) { assignment->stats_.preallocated_temp_fragmentation_bytes = @@ 
-1651,7 +1558,6 @@ StatusOr> BufferAssigner::CreateAssignment( VLOG(3) << "After coloring:"; XLA_VLOG_LINES(3, assignment->alias_analysis().dataflow_analysis().ToString()); - TF_RETURN_IF_ERROR(MergeInplaceOpBuffers(assignment.get())); std::vector thread_local_computations; std::vector global_computations; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 50a4750601b..dfde46ca4b1 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -635,10 +635,6 @@ class BufferAssigner { absl::flat_hash_set* assigned_buffers, BufferAssignment* assignment); - // Promotes operations (DUS, scatter) to be done in place: If an operation can - // be done in place, merge its buffer with its operand buffer. - Status MergeInplaceOpBuffers(BufferAssignment* assignment); - // Assigns a single hlo buffer to an HLO allocation. Status AssignSingleHloBuffer( const HloBuffer* hlo_buffer, bool is_thread_local, @@ -661,9 +657,9 @@ class BufferAssigner { // Uses the results of the heap simulator to create a single allocation, with // LogicalBuffers packed to specific offsets. - void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result, - BufferAssignment* assignment, - LogicalBuffer::Color color); + void AssignBuffersFromHeapSimulator( + const HeapSimulator::Result& result, + BufferAssignment* assignment, LogicalBuffer::Color color); // Tries to assign the given instruction to the given buffer. Returns if the // assignment was successful. diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index bc024f7144b..b49ca649f9a 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1925,8 +1925,10 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text)); HloInstruction* parameter = m->entry_computation()->GetInstructionWithName("get-tuple-element.4"); - HloInstruction* dus = + HloInstruction* dus1 = m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5"); + HloInstruction* dus2 = + m->entry_computation()->GetInstructionWithName("dynamic-update-slice.9"); auto buffers = RunBufferAssignment(m.get()); @@ -1934,8 +1936,10 @@ ENTRY main { const BufferAllocation& parameter_alloc = GetTopLevelAllocation(*buffers, parameter); - const BufferAllocation& dus_alloc = GetTopLevelAllocation(*buffers, dus); - EXPECT_NE(parameter_alloc, dus_alloc); + const BufferAllocation& dus1_alloc = GetTopLevelAllocation(*buffers, dus1); + EXPECT_EQ(parameter_alloc, dus1_alloc); + const BufferAllocation& dus2_alloc = GetTopLevelAllocation(*buffers, dus2); + EXPECT_EQ(parameter_alloc, dus2_alloc); } } diff --git a/tensorflow/compiler/xla/service/cholesky_expander.cc b/tensorflow/compiler/xla/service/cholesky_expander.cc index 20576cdc52d..ffb0fb4e6ef 100644 --- a/tensorflow/compiler/xla/service/cholesky_expander.cc +++ b/tensorflow/compiler/xla/service/cholesky_expander.cc @@ -35,8 +35,6 @@ limitations under the License. namespace xla { -namespace { - // The Cholesky–Banachiewicz algorithm. See // https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms // for a description. @@ -54,78 +52,70 @@ namespace { // l = temp / l[..., j, j) * mask + l // return l // Returns a (result, error) pair. 
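For reference, the Cholesky–Banachiewicz recurrence that the comment above links to can be written as a plain scalar routine. The sketch below is standalone and illustrative only: the name, dense row-major storage, and the boolean error return are assumptions for the example, not the XLA implementation, which works on batched XlaOps with masking.

#include <cmath>
#include <vector>

// Factors a symmetric positive-definite n x n matrix a (row-major) into a
// lower-triangular L with a = L * L^T. Returns false on a non-positive or
// NaN pivot, playing the role of the (result, error) pair above.
bool CholeskyBanachiewicz(const std::vector<double>& a, int n,
                          std::vector<double>* l_out) {
  std::vector<double>& l = *l_out;
  l.assign(n * n, 0.0);
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j <= i; ++j) {
      // sum = A[i][j] - dot(L[i][0..j), L[j][0..j))
      double sum = a[i * n + j];
      for (int k = 0; k < j; ++k) sum -= l[i * n + k] * l[j * n + k];
      if (i == j) {
        if (sum <= 0.0 || std::isnan(sum)) return false;  // "seen_error"
        l[i * n + j] = std::sqrt(sum);
      } else {
        l[i * n + j] = sum / l[j * n + j];
      }
    }
  }
  return true;
}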
-std::pair CholeskyUnblocked( +StatusOr> CholeskyExpander::CholeskyUnblocked( XlaOp a, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); - auto result = [&]() -> StatusOr> { - TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int n_dims = a_shape.rank(); - const int64 n = ShapeUtil::GetDimension(a_shape, -1); - auto major_dims = AsInt64Slice(a_shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims - 2); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int n_dims = a_shape.rank(); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + auto major_dims = AsInt64Slice(a_shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - 2); - auto matrix_dims = AsInt64Slice(a_shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims); + auto matrix_dims = AsInt64Slice(a_shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims); - XlaOp l = ZerosLike(a); + XlaOp l = ZerosLike(a); - // Construct the for loop body to iterate over rows. - auto body_fn = - [&](XlaOp i, absl::Span loop_vars, - XlaBuilder* body_builder) -> StatusOr> { - std::vector row_shape_dims(major_dims.begin(), major_dims.end()); - std::vector col_shape_dims(major_dims.begin(), major_dims.end()); - auto body_a = loop_vars[0]; - auto body_l = loop_vars[1]; - auto seen_error = loop_vars[2]; - auto iota_row = Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), - n_dims - 1); - auto iota_col = Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), - n_dims - 2); + // Construct the for loop body to iterate over rows. + auto body_fn = [&](XlaOp i, absl::Span loop_vars, + XlaBuilder* body_builder) -> StatusOr> { + std::vector row_shape_dims(major_dims.begin(), major_dims.end()); + std::vector col_shape_dims(major_dims.begin(), major_dims.end()); + auto body_a = loop_vars[0]; + auto body_l = loop_vars[1]; + auto seen_error = loop_vars[2]; + auto iota_row = + Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 1); + auto iota_col = + Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 2); - auto mask_pred = Ge(iota_col, iota_row); - mask_pred = And(mask_pred, Eq(iota_row, i)); - auto mask_zeros = - Zeros(body_builder, - ShapeUtil::MakeShape(a_shape.element_type(), matrix_dims)); - // L * L.T, This matrix has of a lot of multiplying with zero - // (namely, L[:, j:] = 0) and redundant computation, but it is faster - // than slice. - auto l_square = BatchDot(body_l, false, body_l, true, precision); + auto mask_pred = Ge(iota_col, iota_row); + mask_pred = And(mask_pred, Eq(iota_row, i)); + auto mask_zeros = + Zeros(body_builder, + ShapeUtil::MakeShape(a_shape.element_type(), matrix_dims)); + // L * L.T, This matrix has of a lot of multiplying with zero + // (namely, L[:, j:] = 0) and redundant computation, but it is faster + // than slice. 
+ auto l_square = BatchDot(body_l, false, body_l, true, precision); - // A - L*L.T - l_square = body_a - l_square; - auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1}); - l_ii = Sqrt(l_ii); - // L = (A - L*L.T) / l_ii * mask + L - body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l; + // A - L*L.T + l_square = body_a - l_square; + auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1}); + l_ii = Sqrt(l_ii); + // L = (A - L*L.T) / l_ii * mask + L + body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l; - seen_error = - Or(seen_error, Any(Or(Le(l_ii, ZerosLike(l_ii)), IsNan(l_ii)))); + seen_error = + Or(seen_error, Any(Or(Le(l_ii, ZerosLike(l_ii)), IsNan(l_ii)))); - return std::vector{body_a, body_l, seen_error}; - }; + return std::vector{body_a, body_l, seen_error}; + }; - TF_ASSIGN_OR_RETURN( - auto cholesky_while, - ForEachIndex(n, S32, body_fn, {a, l, ConstantR0(builder, false)}, - "unblocked", builder)); + TF_ASSIGN_OR_RETURN( + auto cholesky_while, + ForEachIndex(n, S32, body_fn, {a, l, ConstantR0(builder, false)}, + "unblocked", builder)); - return std::make_pair(cholesky_while[1], cholesky_while[2]); - }(); - if (!result.ok()) { - XlaOp error = builder->ReportError(result.status()); - return {error, error}; - } - return result.ValueOrDie(); + return std::make_pair(cholesky_while[1], cholesky_while[2]); } -XlaOp BuildCholesky(XlaOp a, int64 block_size, - PrecisionConfig::Precision precision) { +XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size, + PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); @@ -162,6 +152,7 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size, XlaOp seen_error = ConstantR0(builder, false); for (int64 i = 0; i < n; i += block_size) { int64 k = std::min(block_size, n - i); + auto panel = SliceInMinorDims(a, {i, i}, {n, i + k}); if (i > 0) { // TODO(phawkins): consider implementing SYRK for the diagonal part of // the panel. @@ -169,28 +160,34 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size, auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); auto delta = BatchDot(lhs, false, rhs, true, precision); - auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); - a = UpdateSliceInMinorDims(a, before - delta, {i, i}); + panel = panel - delta; } // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) - auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto x = SliceInMinorDims(panel, {0, 0}, {k, k}); XlaOp factorized; + // TODO(b/167896062): A failure in one element of a batch shouldn't fail + // other elements. 
XlaOp factorized_error; - std::tie(factorized, factorized_error) = CholeskyUnblocked(x, precision); + if (k == 1) { + factorized = Sqrt(x); + factorized_error = Any(IsNan(factorized)); + } else { + TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision)); + std::tie(factorized, factorized_error) = tile_output; + } seen_error = Or(seen_error, factorized_error); l = UpdateSliceInMinorDims(l, factorized, {i, i}); if (i + k < n) { // l[i+k:, i:i+k] = // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) - auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k}); - auto update = - TriangularSolve(factorized, panel, - /*left_side=*/false, - /*lower=*/true, - /*unit_diagonal=*/false, - /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); + auto update = TriangularSolve( + factorized, SliceInMinorDims(panel, {k, 0}, {n - i, k}), + /*left_side=*/false, + /*lower=*/true, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); l = UpdateSliceInMinorDims(l, update, {i + k, i}); } } @@ -199,8 +196,6 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size, }); } -} // namespace - bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kCholesky; } diff --git a/tensorflow/compiler/xla/service/cholesky_expander.h b/tensorflow/compiler/xla/service/cholesky_expander.h index d2958db1b8c..ee8531d0f48 100644 --- a/tensorflow/compiler/xla/service/cholesky_expander.h +++ b/tensorflow/compiler/xla/service/cholesky_expander.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_ #include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/op_expander_pass.h" namespace xla { @@ -31,7 +32,13 @@ class CholeskyExpander : public OpExpanderPass { StatusOr ExpandInstruction( HloInstruction* instruction) override; + virtual StatusOr> CholeskyUnblocked( + XlaOp a, PrecisionConfig::Precision precision); + private: + XlaOp BuildCholesky(XlaOp a, int64 block_size, + PrecisionConfig::Precision precision); + // Mapping from op signatures to existing computations. absl::flat_hash_map computation_cache_; }; diff --git a/tensorflow/compiler/xla/service/comparison_expander.cc b/tensorflow/compiler/xla/service/comparison_expander.cc new file mode 100644 index 00000000000..5c88ff8cae2 --- /dev/null +++ b/tensorflow/compiler/xla/service/comparison_expander.cc @@ -0,0 +1,133 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/comparison_expander.h" + +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +HloInstruction* BitcastConvertFloatingPointToIntegral( + HloComputation* computation, HloInstruction* value, + const Shape& signed_shape, const Shape& unsigned_shape, + HloInstruction* zero, HloInstruction* max_value) { + // Switch from a floating point value to an integer value in such a way that + // when using the integer value to compare, we get the same result for normal + // values, -NaN is treated as the smallest value, and NaN is treated as + // the largest value. + // If f is a float, and + // x = bit_cast<int32>(f); + // y = x < 0 ? numeric_limits<int32>::max() - x : x; + // then y is ordered as an int32 such that finite values have the obvious + // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning + // and end of the ordering. + // Note that in order to avoid signed overflow, we calculate + // numeric_limits<int32>::max() - x as unsigned, and then convert back to + // signed. + auto signed_value = computation->AddInstruction( + HloInstruction::CreateBitcastConvert(signed_shape, value)); + auto unsigned_value = computation->AddInstruction( + HloInstruction::CreateBitcastConvert(unsigned_shape, value)); + auto flipped_value = computation->AddInstruction(HloInstruction::CreateBinary( + unsigned_shape, HloOpcode::kSubtract, max_value, unsigned_value)); + flipped_value = computation->AddInstruction( + HloInstruction::CreateBitcastConvert(signed_shape, flipped_value)); + auto compare_shape = signed_shape; + compare_shape.set_element_type(PRED); + auto is_negative = computation->AddInstruction(HloInstruction::CreateCompare( + compare_shape, signed_value, zero, ComparisonDirection::kLt)); + return computation->AddInstruction( + HloInstruction::CreateTernary(signed_shape, HloOpcode::kSelect, + is_negative, flipped_value, signed_value)); +} + +bool ComparisonExpander::InstructionMatchesPattern( + HloInstruction* instruction) { + if (HloCompareInstruction* compare = + dynamic_cast<HloCompareInstruction*>(instruction)) { + HloInstruction* lhs = instruction->operands()[0]; + if (compare->type() == Comparison::Type::kFloatTotalOrder && + primitive_util::IsFloatingPointType(lhs->shape().element_type())) { + return true; + } + } + return false; +} + +StatusOr<HloInstruction*> ComparisonExpander::ExpandInstruction( + HloInstruction* instruction) { + CHECK(instruction->opcode() == HloOpcode::kCompare); + HloCompareInstruction* compare = + static_cast<HloCompareInstruction*>(instruction); + CHECK(compare->type() == Comparison::Type::kFloatTotalOrder); + HloComputation* computation = instruction->parent(); + HloInstruction* lhs = instruction->operands()[0]; + HloInstruction* rhs = instruction->operands()[1]; + Shape compare_shape = lhs->shape(); + PrimitiveType compare_type = compare_shape.element_type(); + CHECK(primitive_util::IsFloatingPointType(compare_type)); + // Special-case handling for BF16.
We currently do not support direct + // comparisons with BF16, so we convert to F32 and then use the F32 + // comparison logic. + if (compare_type == BF16) { + compare_type = F32; + compare_shape.set_element_type(compare_type); + lhs = computation->AddInstruction( + HloInstruction::CreateConvert(compare_shape, lhs)); + rhs = computation->AddInstruction( + HloInstruction::CreateConvert(compare_shape, rhs)); + } + + int64 bit_width = primitive_util::BitWidth(compare_type); + PrimitiveType signed_type = + primitive_util::SignedIntegralTypeForBitWidth(bit_width); + PrimitiveType unsigned_type = + primitive_util::UnsignedIntegralTypeForBitWidth(bit_width); + auto signed_shape = compare_shape; + signed_shape.set_element_type(signed_type); + auto unsigned_shape = compare_shape; + unsigned_shape.set_element_type(unsigned_type); + auto zero_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type))); + zero_value = computation->AddInstruction(HloInstruction::CreateBroadcast( + signed_shape, zero_value, zero_value->shape().dimensions())); + auto max_signed = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type))); + auto max_shape = max_signed->shape(); + max_shape.set_element_type(unsigned_type); + auto max_unsigned = computation->AddInstruction( + HloInstruction::CreateConvert(max_shape, max_signed)); + auto max_value = computation->AddInstruction(HloInstruction::CreateBroadcast( + unsigned_shape, max_unsigned, max_shape.dimensions())); + lhs = BitcastConvertFloatingPointToIntegral( + computation, lhs, signed_shape, unsigned_shape, zero_value, max_value); + rhs = BitcastConvertFloatingPointToIntegral( + computation, rhs, signed_shape, unsigned_shape, zero_value, max_value); + auto new_compare = computation->AddInstruction(HloInstruction::CreateCompare( + instruction->shape(), lhs, rhs, compare->direction(), + Comparison::Type::kSigned)); + VLOG(2) << "New comparison instruction for total order:" + << new_compare->ToString() << "\n"; + return new_compare; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/comparison_expander.h b/tensorflow/compiler/xla/service/comparison_expander.h new file mode 100644 index 00000000000..df8b5dc0137 --- /dev/null +++ b/tensorflow/compiler/xla/service/comparison_expander.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" + +namespace xla { + +// A pass which performs expansion of the comparison operator to support total +// order comparison of floating point numbers. 
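To make the total-order trick concrete, the scalar analogue of the key that the expanded HLO computes elementwise before the signed compare might look like the standalone sketch below. It covers f32 only; TotalOrderKey and TotalOrderLt are illustrative names for this example, not part of the pass.

#include <cstdint>
#include <cstring>
#include <limits>

int32_t TotalOrderKey(float f) {
  int32_t s;
  uint32_t u;
  std::memcpy(&s, &f, sizeof f);  // bitcast-convert to s32
  std::memcpy(&u, &f, sizeof f);  // bitcast-convert to u32
  if (s < 0) {
    // Flip negative values. Compute max() - u in unsigned arithmetic to
    // avoid signed overflow, then reinterpret the result as signed.
    const uint32_t flipped =
        static_cast<uint32_t>(std::numeric_limits<int32_t>::max()) - u;
    std::memcpy(&s, &flipped, sizeof flipped);
  }
  // Comparing keys with ordinary signed < now yields a total order:
  // -NaN < -Inf < ... < -0.0 < +0.0 < ... < +Inf < NaN.
  return s;
}

bool TotalOrderLt(float a, float b) {
  return TotalOrderKey(a) < TotalOrderKey(b);
}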
+class ComparisonExpander : public OpExpanderPass { + public: + explicit ComparisonExpander() = default; + ~ComparisonExpander() override = default; + absl::string_view name() const override { return "comparison-expander"; } + + private: + // Returns `true` if `instruction` should be expanded by this pass. + bool InstructionMatchesPattern(HloInstruction* instruction) override; + // Returns a replacement for `instruction`, or nullptr if no replacement is + // needed (e.g. only the to_apply subcomputation of the instruction was + // modified). + StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc index cdda0aeb925..bd72ad22cb2 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc @@ -29,11 +29,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -95,12 +97,23 @@ class BoundaryVisitor { absl::flat_hash_set visited_; }; +template +int64 CountNonLeafOps(const OpCollection& ops) { + absl::flat_hash_set op_set; + for (auto op : ops) { + if (!op_set.contains(op) && op->opcode() != HloOpcode::kConstant) { + op_set.insert(op); + } + } + return op_set.size(); +} + // Returns estimation of potential reuses carried by a given pair of // instructions. 
Use different integers to classify different levels // of reuses This is used as a placeholder only, assuming all // instructions can be fused to enable data reuses int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) { - VLOG(1) << "ConditionalCodeMotion: Add reuses carried by instr: " + VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: " << op->ToString() << "=>" << user->ToString() << "\n"; switch (user->opcode()) { case HloOpcode::kGetTupleElement: @@ -114,9 +127,11 @@ int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) { case HloOpcode::kConstant: case HloOpcode::kGetTupleElement: return 0; + case HloOpcode::kConditional: + return 10; default: // Assume fusion will not happen anyway if user count > 1) - if (op->user_count() > 1) { + if (CountNonLeafOps(op->users()) > 1) { return 0; } return 10; @@ -432,7 +447,8 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( if (to_move_out.empty()) { return false; } - VLOG(1) << "number of boundaries to move out:" << to_move_out.size() << "\n"; + VLOG(1) << "Modifying code--number of boundaries to move out:" + << to_move_out.size() << "\n"; HloComputation* conditional_parent = conditional->parent(); // save the old users before add new conditional user instructions std::vector old_conditional_users = conditional->users(); @@ -441,7 +457,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( absl::flat_hash_map hoisted_instructions; // Insert GetTupleElement before the instructions whose operands might still // be within the conditional. - VLOG(2) << "before opt:" + VLOG(1) << "before opt:" << conditional_parent->ToString(HloPrintOptions::Fingerprint()) << "\n"; int64 op_index = 0; @@ -470,16 +486,22 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( HloInstruction* old_root = conditional->branch_computation(0)->root_instruction(); for (auto user_instr : old_conditional_users) { + VLOG(2) << "Checking conditional user: " << user_instr->ToString() << "\n"; CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement); auto tuple_opd = static_cast(user_instr); int64 index = tuple_opd->tuple_index(); + CHECK(old_root->operands().size() > index); HloInstruction* old_opd = old_root->operands()[index]; + CHECK(ContainsKey(hoisted_instructions, old_opd)); HloInstruction* new_opd = hoisted_instructions[old_opd].operands()[0]; CHECK(old_opd != nullptr); CHECK(new_opd != nullptr); + VLOG(2) << "Try replace all uses of :" << old_opd->ToString() << "\n"; TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd)); TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr)); } + VLOG(2) << "Done changing conditional users\n" + << conditional_parent->ToString() << "\n"; // Create tuple element within each branch and set it as root. int64 branch_count = conditional->branch_count(); for (int i = 0; i < branch_count; i++) { @@ -487,9 +509,8 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( std::vector elements; for (auto b1 : new_boundaries) { HloInstruction* op = b1.operands()[i]; - VLOG(1) << "branch count=" << i << "\n"; CHECK(op != nullptr); - VLOG(1) << "Adding to root " << i << " with " << op->ToString() << "\n"; + VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n"; elements.push_back(op); } HloInstruction* tuple = @@ -498,8 +519,16 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( VLOG(2) << "computation is :" << computation->ToString() << "\n"; // Remove hoisted instructions from the branches. 
for (auto b2 : to_move_out) { - VLOG(2) << "Removing boundary:" << b2.ToString() << "\n"; - TF_RETURN_IF_ERROR(computation->RemoveInstruction(b2.operands()[i])); + auto instr_to_remove = b2.operands()[i]; + // Double check to make sure it is safe to delete the instruction. + // Complications may arise due to some operations in the alternative + // branches (branches 1..n) being placed into the boundaries multiple + // times. + if (!computation->IsMarkedAsDead(instr_to_remove) && + instr_to_remove->user_count() == 0) { + VLOG(2) << "Removing boundary:" << b2.ToString() << "\n"; + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instr_to_remove)); + } } } // Change conditional instruction shape to the shape of the new root. @@ -507,7 +536,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( conditional->branch_computation(0)->root_instruction(); *conditional->mutable_shape() = new_root->shape(); // - VLOG(2) << "done moving instructions out of branches\n" + VLOG(1) << "done moving instructions out of branches\n" << conditional_parent->ToString(HloPrintOptions::Fingerprint()) << "\n"; return true; @@ -520,48 +549,89 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( if (to_move_in.empty()) { return false; } - VLOG(1) << "number of boundaries to move in:" << to_move_in.size() << "\n"; - HloComputation* conditional_parent = conditional->parent(); - VLOG(2) << "before opt:" - << conditional_parent->ToString(HloPrintOptions::Fingerprint()) + VLOG(1) << "Modifying code---number of boundaries to move in:" + << to_move_in.size() << "\n"; + VLOG(1) << "before opt:" + << conditional->parent()->ToString(HloPrintOptions::Fingerprint()) << "\n"; // Mapping instructions to be moved to their new representations. absl::flat_hash_map hoisted_instructions; int64 to_move_in_size = to_move_in.size(); int64 branch_count = conditional->branch_count(); - int64 op_index = conditional->shape().tuple_shapes_size(); - // Map conditional to its old root, then create a new root instruction in each - // branch. - Boundary b(Boundary::Position::kInsideBranch); + // Number of old conditional entries still to be used outside. + // If conditional shape is not tuple, will create a tuple and use subscript + // 0 to save the old operand being used. + int64 op_index = conditional->shape().IsTuple() + ? conditional->shape().tuple_shapes_size() - 1 + : 0; + HloGetTupleElementInstruction* tuple_use = + dynamic_cast(to_move_in[0].operands()[0]); + int64 use_index = (tuple_use != nullptr) ? tuple_use->tuple_index() : -1; + VLOG(2) << "Tuple use index = " << use_index << "\n"; + // Use to map the tuple_use instruction to its operand; + Boundary b_opd_use(Boundary::Position::kInsideBranch); + Boundary b_old_root(Boundary::Position::kInsideBranch); + // Create a new root instruction in each branch. 
for (int i = 0; i < branch_count; i++) { auto computation = conditional->branch_computation(i); auto old_root = computation->root_instruction(); - b.mutable_operands().push_back(old_root); - HloInstruction* new_root = nullptr; + b_old_root.mutable_operands().push_back(old_root); + std::vector operands; if (old_root->opcode() == HloOpcode::kTuple) { - new_root = computation->AddInstruction(old_root->Clone()); - } else { - std::vector operands; - if (!old_root->shape().IsTuple()) { - operands.push_back(old_root); - } else { - const Shape& old_shape = old_root->shape(); - for (int64 i = 0; i < old_shape.tuple_shapes_size(); ++i) { - auto element = - computation->AddInstruction(HloInstruction::CreateGetTupleElement( - old_shape.tuple_shapes(i), old_root, i)); - operands.push_back(element); + // Use operands of old_root directly, so old_root can be removed later. + for (int i = 0; i < old_root->operand_count(); ++i) { + if (i != use_index) { + operands.push_back(old_root->operands()[i]); + } else { // Map conditional use to the tuple operand. + b_opd_use.mutable_operands().push_back(old_root->operands()[i]); } } - new_root = - computation->AddInstruction(HloInstruction::CreateTuple(operands)); + } else if (old_root->shape().IsTuple()) { + // If old_root is not a kTuple but has tuple shape, elements within the + // tuple must be extracted first to be used by the new instructions. + const Shape& old_shape = old_root->shape(); + for (int64 i = 0; i < old_shape.tuple_shapes_size(); ++i) { + auto element = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + old_shape.tuple_shapes(i), old_root, i)); + if (i != use_index) { + operands.push_back(element); + } else { + b_opd_use.mutable_operands().push_back(element); + } + } + } else { + // If old_root is not a tuple and does not have tuple shape, use it + // to replace the conditional directly in the new computation. + b_opd_use.mutable_operands().push_back(conditional); } + + HloInstruction* new_root = + computation->AddInstruction(HloInstruction::CreateTuple(operands)); VLOG(2) << "setting new root: " << new_root->ToString() << "\n"; - computation->set_root_instruction(new_root); + computation->set_root_instruction(new_root, + /*accept_different_shape*/ true); + if (old_root->opcode() == HloOpcode::kTuple) { + TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_root)); + } VLOG(2) << "new branch computation: " << computation->ToString() << "\n"; } - hoisted_instructions[conditional] = b; - for (int64 i = 0; i < to_move_in_size; i++) { + // Update get tuple element index of the conditional. 
+ if (use_index != -1) { + for (auto* user : conditional->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() > use_index) { + user->set_tuple_index(user->tuple_index() - 1); + } + } + } + hoisted_instructions[conditional] = b_old_root; + int64 cp_start = 0; + if (use_index >= 0) { + hoisted_instructions[tuple_use] = b_opd_use; + cp_start = 1; + } + for (int64 i = cp_start; i < to_move_in_size; i++) { Boundary b_to_move = to_move_in[i]; HloInstruction* op = b_to_move.operands()[0]; CHECK(op != nullptr); @@ -591,12 +661,12 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( } if (to_be_used_outside) { // Modify uses of instructions outside of the conditionals - HloInstruction* gtr = conditional_parent->AddInstruction( + HloInstruction* gtr = conditional->parent()->AddInstruction( HloInstruction::CreateGetTupleElement(op->shape(), conditional, op_index++)); TF_RETURN_IF_ERROR(op->ReplaceAllUsesWith(gtr)); - if (conditional_parent->root_instruction() == op) { - conditional_parent->set_root_instruction(gtr); + if (conditional->parent()->root_instruction() == op) { + conditional->parent()->set_root_instruction(gtr); } } } @@ -606,8 +676,8 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( HloInstruction* new_root = conditional->branch_computation(0)->root_instruction(); *conditional->mutable_shape() = new_root->shape(); - VLOG(2) << "Before removing instructions:" << conditional_parent->ToString() - << "\n"; + VLOG(2) << "Before removing instructions:" + << conditional->parent()->ToString() << "\n"; // Remove hoisted instructions from the branches. for (int64 i = to_move_in_size - 1; i >= 0; i--) { Boundary boundary_to_move_in = to_move_in[i]; @@ -616,10 +686,20 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( for (auto user : op->users()) { VLOG(2) << "Has User: " << user->ToString() << "\n"; } - TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(op)); + TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op)); } - VLOG(2) << "Done moving instructions inside branches\n" - << conditional_parent->ToString(HloPrintOptions::Fingerprint()) + + // Reset shapes of user gtes to the new shape. + if (use_index != -1) { + for (auto* user : conditional->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement) { + *user->mutable_shape() = + conditional->shape().tuple_shapes(user->tuple_index()); + } + } + } + VLOG(1) << "Done moving instructions inside branches\n" + << conditional->parent()->ToString(HloPrintOptions::Fingerprint()) << "\n"; return true; } @@ -631,6 +711,7 @@ class GroupConnectedBoundaries { HloInstruction* conditional_; HloComputation* conditional_parent_; bool is_layout_sensitive_; + // Instructions that have been visited but are not going to be moved. absl::flat_hash_set visited_; public: @@ -639,7 +720,7 @@ class GroupConnectedBoundaries { : conditional_(conditional), conditional_parent_(conditional->parent()), is_layout_sensitive_(is_layout_sensitive) {} - // Returns true if `instruction` is worth hoisting out. + // Returns true if `instruction` is worth hoisting. 
bool WorthHoisting(HloInstruction* instruction) { // This is needed for the "moving-in" transformation, to prevent the root // of the parent computation (which contains the conditional) to be moved @@ -663,13 +744,14 @@ class GroupConnectedBoundaries { case HloOpcode::kReshape: return true; default: - VLOG(1) << "Instruction is convert and its operand is not know to " + VLOG(2) << "Instruction is convert and its operand is not know to " "be worth hoisting\n"; return false; } case HloOpcode::kAllReduce: case HloOpcode::kAdd: case HloOpcode::kPower: + case HloOpcode::kCopy: case HloOpcode::kConstant: case HloOpcode::kSubtract: case HloOpcode::kMultiply: @@ -680,24 +762,28 @@ class GroupConnectedBoundaries { case HloOpcode::kGetTupleElement: return true; default: - VLOG(1) << "Instruction is not known to be worth hoisting\n"; + VLOG(2) << "Instruction is not known to be worth hoisting\n"; return false; } } int64 ReusesBeforeBoundary(HloInstruction* user) { int64 reuses = 0; for (auto op : user->operands()) { + // The operand must be an instruction that is not going to be moved (if + // user is inside the conditional); otherwise it must be the conditional + // itself and its user must be outside of the conditional. + if (!ContainsKey(visited_, op) && op != conditional_) { + continue; + } // Only consider single-user cases as reuseable. - if (ContainsKey(visited_, op) && op->user_count() == 1) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->user_count() == 1) { + reuses += ReusesCarriedBy(op, user->users()[0]); + } else if (op->user_count() == 1) { reuses += ReusesCarriedBy(op, user); - } else if (op->opcode() == HloOpcode::kConditional && - user->opcode() == HloOpcode::kGetTupleElement) { - if (user->user_count() == 1) { - reuses += ReusesCarriedBy(op, user->users()[0]); - } } } - VLOG(1) << "Reuses before instruction " << user->ToString() << ":" << reuses + VLOG(2) << "Reuses before instruction " << user->ToString() << ":" << reuses << "\n"; return reuses; } @@ -735,7 +821,7 @@ class GroupConnectedBoundaries { } else if (ContainsKey(visited_, op)) { reuses += ReusesCarriedBy(user, op); } - VLOG(1) << "reuses after instruction " << user->ToString() << ":" + VLOG(2) << "reuses after instruction " << user->ToString() << ":" << reuses << "\n"; return reuses; } @@ -744,7 +830,8 @@ class GroupConnectedBoundaries { int64 BenefitForMovingBoundaries(const std::vector& boundaries) { int64 reuses_before = 0, reuses_after = 0; - if (boundaries.size() == 1 && boundaries[0].IsOutsideBranch()) { + if (boundaries.size() == 1 && boundaries[0].IsOutsideBranch() && + boundaries[0].operands()[0]->opcode() == HloOpcode::kGetTupleElement) { // The only boundary of moving-in is the get_tuple_element op. 
return -1; } @@ -754,16 +841,16 @@ class GroupConnectedBoundaries { continue; } reuses_before += ReusesBeforeBoundary(op); - VLOG(1) << "Reuses before boundary so far: " << reuses_before << "\n"; + VLOG(2) << "Reuses before boundary so far: " << reuses_before << "\n"; reuses_after += ReusesAfterBoundary(op); - VLOG(1) << "Reuese after boundary so far : " << reuses_after << "\n"; + VLOG(2) << "Reuese after boundary so far : " << reuses_after << "\n"; } if (reuses_after == 0 && reuses_before == 0) { return -1; } else if (boundaries[0].IsInsideBranch()) { return reuses_after - reuses_before; } else { - return reuses_before - reuses_after; + return reuses_before - reuses_after - 1; } } @@ -779,17 +866,6 @@ class GroupConnectedBoundaries { } return b2; } - int64 CountNonLeafOps(const xla::HloInstruction::InstructionVector& ops) { - int64 count = 0; - absl::flat_hash_set op_set; - for (auto op : ops) { - if (!op_set.contains(op) && op->opcode() != HloOpcode::kConstant) { - count++; - op_set.insert(op); - } - } - return count; - } // This function is reused both for moving the boundary outside or into a // conditional. As the result, the readability is somewhat compromised. // It might be nice to refactor this function to factor the outside-inside @@ -800,12 +876,12 @@ class GroupConnectedBoundaries { visitor.AddToWorkList(boundary); while (visitor.HasNextBoundary()) { Boundary b = visitor.PopNextBoundary(); - VLOG(1) << "visiting boundary " << b.ToString() << "\n"; + VLOG(2) << "visiting boundary " << b.ToString() << "\n"; if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical( b.operands(), is_layout_sensitive_)) && WorthHoisting(b.operands()[0])) { connected_boundaries_.push_back(b); - VLOG(1) << "boundary can be moved\n"; + VLOG(2) << "boundary can be moved\n"; int64 operand_count = (b.IsInsideBranch()) ? b.operands()[0]->operand_count() : b.operands()[0]->users().size(); @@ -829,20 +905,21 @@ class GroupConnectedBoundaries { } } } else { - VLOG(1) << "boundary cannot be moved\n"; + VLOG(2) << "boundary cannot be moved\n"; visited_.insert(b.operands()[0]); new_boundaries_.push_back(b); } } } - std::vector BoundariesToMoveInOrOut(const Boundary& b) { + std::vector BoundariesToMoveInOrOut(HloInstruction* conditional, + const Boundary& b) { // At the beginning of optimization, a conditional itself is added to a // worklist. Here the conditional is expanded into two sets of boundaries: // the first set contains the boundary that is inside branches and // contains the root of all branches; the second set of boundaries // contains all the users of the conditional. HloInstruction* inst = b.operands()[0]; - if (inst->opcode() == HloOpcode::kConditional) { + if (inst == conditional) { int branch_count = inst->branch_count(); // Add conditional roots as a new boundary to visit. 
Boundary boundary_in(Boundary::Position::kInsideBranch); @@ -873,10 +950,11 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion( HloInstruction* conditional, const Boundary& cur_boundary, std::vector& to_move, std::vector& new_boundaries) { GroupConnectedBoundaries connect(conditional, is_layout_sensitive_); - auto move_in_or_out = connect.BoundariesToMoveInOrOut(cur_boundary); + auto move_in_or_out = + connect.BoundariesToMoveInOrOut(conditional, cur_boundary); if (!move_in_or_out.empty()) { auto benefit = connect.BenefitForMovingBoundaries(move_in_or_out); - VLOG(1) << "benefit of moving in or out " + VLOG(2) << "benefit of moving in or out " << cur_boundary.operands()[0]->ToString() << ":" << benefit << "\n"; if (benefit >= 0) { new_boundaries.clear(); @@ -896,19 +974,62 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion( } StatusOr ConditionalCodeMotion::Run(HloModule* module) { + bool changed = false; + bool cleanup_changed = false; + { + HloPassPipeline subpipeline("before_conditional_code_motion"); + subpipeline.AddPass(/*is_layout_sensitive=*/is_layout_sensitive_); + subpipeline.AddPass(); + TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module)); + cleanup_changed |= cleanup_changed_now; + } // Gather all the conditional ops in the module ahead of time, to avoid // potential complications of modifying the code that affecting traversal. std::vector conditional_ops; + // Track how many times each branch computation is shared. + absl::flat_hash_map conditional_computations; for (auto* comp : module->MakeComputationPostOrder()) { for (auto* instr : comp->MakeInstructionPostOrder()) { if (instr->opcode() == HloOpcode::kConditional) { - conditional_ops.push_back(instr); + int branch_count = instr->branch_count(); + for (int i = 0; i < branch_count; ++i) { + HloComputation* branch_i = instr->branch_computation(i); + if (ContainsKey(conditional_computations, branch_i)) { + conditional_computations[branch_i]++; + } else { + conditional_computations[branch_i] = 0; + } + } + if (instr->shape().IsTuple()) { + bool can_change_tuple_shape = true; + for (auto user : instr->users()) { + VLOG(2) << "user is : " << user->ToString() << "\n"; + if (user->opcode() != HloOpcode::kGetTupleElement) { + can_change_tuple_shape = false; + } + } + if (can_change_tuple_shape) { + conditional_ops.push_back(instr); + } + } else { + conditional_ops.push_back(instr); + } } } } - bool changed = false; for (HloInstruction* conditional : conditional_ops) { + int branch_count = conditional->branch_count(); + // check for shared conditional computations + bool conditional_is_shared = false; + for (int i = 0; i < branch_count; ++i) { + HloComputation* branch_i = conditional->branch_computation(i); + if (conditional_computations[branch_i] > 0) { + conditional_is_shared = true; + break; + } + } + // Boundaries to move out or to move into the branches. std::vector to_move_out, to_move_in, new_boundaries; // The conditional is moved into a worklist as the seed (starting point). 
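Illustrative sketch (editorial, not part of the patch): the conditional_computations map built in Run() above flags a branch computation as shared when more than one conditional refers to it, because code motion rewrites branch roots in place and must not mutate a computation another conditional still uses. The helper below, with a hypothetical name (CountBranchSharing), restates that counting rule using only APIs already exercised by this patch (branch_count, branch_computation):

#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

// Sketch only: for every branch computation, count how many conditionals
// reference it beyond the first one, so a count > 0 means "shared" and the
// computation must be cloned before its root is rewritten.
absl::flat_hash_map<xla::HloComputation*, int> CountBranchSharing(
    const std::vector<xla::HloInstruction*>& conditionals) {
  absl::flat_hash_map<xla::HloComputation*, int> shared_count;
  for (xla::HloInstruction* conditional : conditionals) {
    for (int i = 0; i < conditional->branch_count(); ++i) {
      xla::HloComputation* branch = conditional->branch_computation(i);
      auto it = shared_count.find(branch);
      if (it == shared_count.end()) {
        shared_count[branch] = 0;  // First reference: not shared yet.
      } else {
        ++it->second;  // Any further reference marks the branch as shared.
      }
    }
  }
  return shared_count;
}

Counting the first reference as zero mirrors the patch, where conditional_computations[branch_i] > 0 is the sharedness test that later triggers cloning.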
@@ -926,6 +1047,33 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { Boundary boundary = visitor.PopNextBoundary(); VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n"; d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary); + if (d != Decision::kNoChange && conditional_is_shared) { + for (int i = 0; i < branch_count; ++i) { + HloComputation* branch_i = conditional->branch_computation(i); + if (conditional_computations[branch_i] > 0) { + // Cloning is absolutely needed if the computation is shared by + // different branches, but the cloning can be potentially avoided + // if the sharing is only among branches of the same conditional. + // If cloning these branches causes a problem due to space issues, + // a fix can pass a vector of unique branches to the actual + // transformations, as an alternative representation of the + // conditional branches to be modified. Right now we assume the + // overhead of cloning is minimal since later stages of the compiler + // inline all the computations anyway. + HloComputation* clone_i = + conditional->parent()->parent()->AddEmbeddedComputation( + branch_i->Clone()); + conditional->set_branch_computation(i, clone_i); + conditional_computations[branch_i]--; + } + } + to_move.clear(); + next_boundary.clear(); + VLOG(2) << "Cloned branches as needed: " << conditional->ToString() + << "\n"; + // Need to reanalyze the cloned code to generate a correct result. + d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary); + } switch (d) { case Decision::kMoveOutOfBranch: VLOG(2) << "Decision is move out of branch\n"; @@ -961,22 +1109,14 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { MoveInstructionIn(conditional, to_move_in, new_boundaries)); VLOG(2) << "moving in result:" << result << "\n"; changed |= result; - } - } - // handling convert rematerialization/hoisting - if (!changed && pursue_full_conditional_code_motion_) { - std::vector conditional_ops; - for (auto* comp : module->MakeComputationPostOrder()) { - for (auto* instr : comp->MakeInstructionPostOrder()) { - if (instr->opcode() == HloOpcode::kConditional) { - conditional_ops.push_back(instr); - } - } - } - for (HloInstruction* conditional_op : conditional_ops) { + } else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) { + // Invoke special handling for convert rematerialization/hoisting + // We need to make sure no sharing is present in the branches because no + // cloning has been done by the earlier analysis. + // TODO[b/165848866]: extend solution to handle cloning for special move.
TF_ASSIGN_OR_RETURN( bool convert_result, - ConvertSpecialMove(conditional_op, is_layout_sensitive_)); + ConvertSpecialMove(conditional, is_layout_sensitive_)); changed |= convert_result; } } @@ -986,8 +1126,11 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { subpipeline.AddPass(); subpipeline.AddPass(); subpipeline.AddPass(); - TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module)); - changed |= cleanup_changed; + TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module)); + cleanup_changed |= cleanup_changed_now; + } + if (cleanup_changed) { + VLOG(2) << "subpipeline cleanup have modified code\n"; } return changed; } diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc index b0a6ba92f48..3b772221446 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc @@ -158,6 +158,44 @@ ENTRY main { EXPECT_THAT(root, AllOf(op::Tuple(op::Convert()))); } +TEST_F(ConditionalCodeMotionTest, ConditionalShapeNotMutable) { + absl::string_view hlo_string = + R"( +HloModule RemoveDotOpOut + +on_true { + %arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0) + %get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0 + %reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1) + %add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.8493, f32[2,512,364]{2,1,0} %reshape.8493) + %convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %add.8493) + ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894) +} + +on_false { + %arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0) + %get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0 + %reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3) + %add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717) + %sub.8493 = f32[2,512,364]{2,1,0} subtract(f32[2,512,364]{2,1,0} %add.8493, f32[2,512,364]{2,1,0} %reshape.9717) + %convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"} + ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1) + arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2) + conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false + get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0 + ROOT result = (bf16[2,512,364]{2,1,0}, (bf16[2,512,364]{2,1,0})) tuple(get-first-index, conditional) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); +} + TEST_F(ConditionalCodeMotionTest, MoveConvertOut) { absl::string_view hlo_string = R"( @@ -580,6 +618,347 @@ ENTRY main { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); } + +TEST_F(ConditionalCodeMotionTest, NoMoveInWithMultipleGTE) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +on_true { + arg_tuple.1 = (f32[10]) parameter(0) + get-tuple-element.1 = f32[10] get-tuple-element(arg_tuple.1), 
index=0 + add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1) + ROOT tuple.3 = (f32[10]) tuple(add.1) +} + +on_false { + arg_tuple.2 = (f32[10]) parameter(0) + get-tuple-element.2 = f32[10] get-tuple-element(arg_tuple.2), index=0 + mul.1 = f32[10] multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.4 = (f32[10]) tuple(mul.1) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[10]) parameter(1) + tuple.2 = (f32[10]) parameter(2) + conditional = (f32[10]) + conditional(pred.1, tuple.1, tuple.2), true_computation=on_true, + false_computation=on_false + get-first-index = f32[10] get-tuple-element(conditional), index=0 + get-first-index.2 = f32[10] get-tuple-element(conditional), index=0 + pow.1 = f32[10] power(get-first-index, get-first-index) + ROOT tuple.3 = (f32[10], f32[10]) tuple(pow.1, get-first-index.2) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Tuple(op::Power(), op::GetTupleElement(op::Conditional()))); +} + +TEST_F(ConditionalCodeMotionTest, MovePowInWithSharedBranch) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +branch { + arg_tuple.1 = (f32[10]) parameter(0) + get-tuple-element.1 = f32[10] get-tuple-element(arg_tuple.1), index=0 + add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1) + ROOT tuple.3 = (f32[10]) tuple(add.1) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[10]) parameter(1) + tuple.2 = (f32[10]) parameter(2) + conditional = (f32[10]) + conditional(pred.1, tuple.1, tuple.2), true_computation=branch, + false_computation=branch + get-first-index = f32[10] get-tuple-element(conditional), index=0 + ROOT pow.1 = f32[10] power(get-first-index, get-first-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 5); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 5); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); +} + +TEST_F(ConditionalCodeMotionTest, MovePowInWithNonTupleRoot) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +branch { + arg_tuple.1 = (f32[10]) parameter(0) + get-tuple-element.1 = f32[10] get-tuple-element(arg_tuple.1), index=0 + ROOT add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[10]) parameter(1) + tuple.2 = (f32[10]) parameter(2) + conditional = f32[10] + conditional(pred.1, tuple.1, tuple.2), true_computation=branch, + false_computation=branch + ROOT pow.1 = f32[10] power(conditional, conditional) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + 
ASSERT_EQ(on_true->instruction_count(), 5); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 5); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); +} + +TEST_F(ConditionalCodeMotionTest, MovePowInWithEmptyBranch) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +branch1 { + arg_tuple.1 = (f32[10]) parameter(0) + get-tuple-element.1 = f32[10] get-tuple-element(arg_tuple.1), index=0 + add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1) + ROOT tuple.3 = (f32[10]) tuple(add.1) +} + +branch2 { + ROOT arg_tuple.1 = (f32[10]) parameter(0) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[10]) parameter(1) + tuple.2 = (f32[10]) parameter(2) + conditional = (f32[10]) + conditional(pred.1, tuple.1, tuple.2), true_computation=branch1, + false_computation=branch2 + get-first-index = f32[10] get-tuple-element(conditional), index=0 + ROOT pow.1 = f32[10] power(get-first-index, get-first-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 5); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 4); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); +} + +TEST_F(ConditionalCodeMotionTest, MovePowInWithNonTupleParameter) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +branch { + arg.1 = f32[10] parameter(0) + ROOT add.1 = f32[10] add(arg.1, arg.1) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = f32[10] parameter(1) + tuple.2 = f32[10] parameter(2) + conditional = f32[10] + conditional(pred.1, tuple.1, tuple.2), true_computation=branch, + false_computation=branch + ROOT pow.1 = f32[10] power(conditional, conditional) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 4); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 4); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); +} + +TEST_F(ConditionalCodeMotionTest, MoveCopyInBranch) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +branch1 { + arg_tuple.1 = (s32[], f32[10,3]{0,1}) parameter(0) + constant.1 = s32[] constant(4) + get-tuple-element.1 = s32[] get-tuple-element(arg_tuple.1), index=0 + add.1 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = f32[10,3]{0,1} get-tuple-element(arg_tuple.1), index=1 + slice.1 = f32[4,3]{0,1} slice(get-tuple-element.2), + slice={[0:4:1], [0:3:1]} + constant.2 = f32[] constant(0.0) + ROOT tuple.1 = (f32[4,3]{0,1}, s32[],f32[]) tuple(slice.1, add.1, constant.2) +} + +branch2 { + 
arg_tuple.2 = (s32[], f32[4,3]{1,0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(arg_tuple.2), index=0 + copy.1 = s32[] copy(get-tuple-element.3) + get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element(arg_tuple.2), index=1 + copy.2 = f32[4,3]{0,1} copy(get-tuple-element.4) + constant.2 = f32[] constant(0.0) + ROOT tuple.2 = (f32[4,3]{0,1}, s32[], f32[]) tuple(copy.2, copy.1, constant.2) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.3 = (s32[], f32[10,3]{0,1}) parameter(1) + tuple.4 = (s32[], f32[4,3]{1,0}) parameter(2) + conditional = (f32[4,3]{0,1}, s32[], f32[]) + conditional(pred.1, tuple.3, tuple.4), true_computation=branch1, + false_computation=branch2 + get-zero-index = f32[4,3]{0,1} get-tuple-element(conditional), index=0 + get-first-index = s32[] get-tuple-element(conditional), index=1 + get-second-index = f32[] get-tuple-element(conditional), index=2 + copy.3 = f32[4,3]{1,0} copy(get-zero-index) + ROOT tuple.5 = (f32[4,3]{0,1}, s32[], f32[]) tuple(copy.3, get-first-index, + get-second-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + VLOG(1) << module->ToString(); + + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 9); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 8); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Tuple(op::GetTupleElement(op::Conditional(), 2), + op::GetTupleElement(op::Conditional(), 0), + op::GetTupleElement(op::Conditional(), 1)))); +} + +TEST_F(ConditionalCodeMotionTest, MoveReplicatedTupleEntryOut) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +%add.64 (x.139: bf16[], y.139: bf16[]) -> bf16[] { + %x.139 = bf16[]{:T(512)} parameter(0) + %y.139 = bf16[]{:T(512)} parameter(1) + ROOT %add.44073 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.139, bf16[]{:T(512)} %y.139) +} + +%add.181 (x.256: bf16[], y.256: bf16[]) -> bf16[] { + %x.256 = bf16[]{:T(512)} parameter(0) + %y.256 = bf16[]{:T(512)} parameter(1) + ROOT %add.44842 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.256, bf16[]{:T(512)} %y.256) +} + +on_true { + arg_tuple.1 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(0) + get-tuple-element.11 = bf16[2,54,168,128] get-tuple-element(arg_tuple.1), index=0 + get-tuple-element.12 = bf16[2,52,168,128] get-tuple-element(arg_tuple.1), index=1 + convolution.1 = bf16[3,3,128,128] convolution(bf16[2,54,168,128] + get-tuple-element.11, bf16[2,52,168,128] + get-tuple-element.12), window={size=52x168 pad=0_0x1_1}, + dim_labels=f01b_i01o->01bf + all-reduce.1 = bf16[3,3,128,128] + all-reduce(bf16[3,3,128,128] %convolution.1), + channel_id=188, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.64 + convert.1 = f32[3,3,128,128] convert(bf16[3,3,128,128] %all-reduce.1) + all-reduce.3 = bf16[3,3,128,128] + all-reduce(bf16[3,3,128,128] %convolution.1), + channel_id=188, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.64 + convert.3 = f32[3,3,128,128] convert(bf16[3,3,128,128] %all-reduce.3) + ROOT tuple.1 = (f32[3,3,128,128], f32[3,3,128,128]) tuple(convert.1, convert.3) +} + +on_false { + arg_tuple.2 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(0) + get-tuple-element.21 = 
bf16[2,86,104,128] + get-tuple-element(arg_tuple.2), index=0 + get-tuple-element.22 = bf16[2,84,104,128] + get-tuple-element(arg_tuple.2), index=1 + convolution.2 = bf16[3,3,128,128] + convolution(bf16[2,86,104,128] get-tuple-element.21, bf16[2,84,104,128] + get-tuple-element.22), window={size=84x104 pad=0_0x1_1}, + dim_labels=f01b_i01o->01bf + all-reduce.2 = bf16[3,3,128,128] + all-reduce(bf16[3,3,128,128] %convolution.2), + channel_id=485, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.181 + convert.2 = f32[3,3,128,128] + convert(bf16[3,3,128,128] %all-reduce.2) + ROOT tuple.2 = (f32[3,3,128,128], f32[3,3,128,128]) tuple(convert.2, convert.2) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.3 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(1) + arg_tuple.4 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(2) + conditional = (f32[3,3,128,128], f32[3,3,128,128]) + conditional(pred.1, arg_tuple.3, arg_tuple.4), true_computation=on_true, + false_computation=on_false + get-first-index = f32[3,3,128,128] + get-tuple-element(conditional), index=0 + add.1 = f32[3,3,128,128] add(f32[3,3,128,128] get-first-index, f32[3,3,128,128] get-first-index) + ROOT result = (f32[3,3,128,128]) tuple(add.1) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 5); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 5); + + // Checks if conditional shape has changed. 
+ ASSERT_TRUE(ShapeUtil::Compatible( + conditional->shape(), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape( + BF16, {3, 3, 128, 128})}))); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Tuple(op::Add( + op::Convert(op::AllReduce(op::GetTupleElement(op::Conditional()))), + op::Convert( + op::AllReduce(op::GetTupleElement(op::Conditional()))))))); +} + } // namespace conditional_opt } // namespace xla diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index 323bf44dcd3..f5506b894fd 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -300,7 +300,8 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { window_dim->set_window_dilation(1); HloInstruction* new_convolution = MakeConvolveHlo(activation, filter, convolution->feature_group_count(), - window, dim_numbers, convolution->precision_config()) + /*batch_group_count=*/1, window, dim_numbers, + convolution->precision_config()) .ValueOrDie(); convolution->SetupDerivedInstruction(new_convolution); TF_CHECK_OK(computation_->ReplaceInstruction( @@ -649,7 +650,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { window_dim->set_window_reversal(false); window_dim->set_window_dilation(1); HloInstruction* new_convolution = - MakeConvolveHlo(activation, filter, 1, window, dim_numbers, + MakeConvolveHlo(activation, filter, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dim_numbers, convolution->precision_config()) .ValueOrDie(); convolution->SetupDerivedInstruction(new_convolution); diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index b88120d8128..f2e37ca23b6 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -362,6 +362,19 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, return Status::OK(); } +// Add copies for the operands of in-place operations. RemoveUnnecessaryCopies +// will remove the unnecessary copies. +Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis, + HloInstruction* in_place_op, + int64 operand_number) { + VLOG(2) << "Adding copies for in-place operation " << in_place_op->name(); + HloInstruction* operand = in_place_op->mutable_operand(operand_number); + TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, + in_place_op->parent()->DeepCopyInstruction(operand)); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(in_place_op, deep_copy)); + return Status::OK(); +} + // Conservatively adds copies before root instruction of entry computation and // each aliased parameter to resolve interference of aliased input and output // buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary @@ -509,6 +522,12 @@ class CopyRemover { // value. The map is used to construct the copy info map below. absl::flat_hash_map value_to_node; for (const HloBuffer& buffer : alias_analysis.buffers()) { + // No copies should have been inserted within fused computations, so no + // need to remove them. HloOrdering isn't compatible with HloValues inside + // fusions, so skip copy removal for them. + if (buffer.values().at(0)->defining_instruction()->IsFused()) { + continue; + } // Verify values contained in the buffer are strictly ordered. 
This // should always be the case after adding copies to eliminate // interference. Specifically, the addition of the control flow edges @@ -591,7 +610,7 @@ class CopyRemover { void CreateCopyMap( const HloModule& module, const absl::flat_hash_map& value_to_node) { - for (HloComputation* computation : module.computations()) { + for (HloComputation* computation : module.MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { // Add copies with unambiguous source values to the map. Copies with // ambiguous sources are not removable. @@ -1005,7 +1024,7 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module, can_share_buffer_)); - for (HloComputation* computation : module->MakeComputationPostOrder()) { + for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kWhile) { @@ -1013,6 +1032,15 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { } else if (instruction->opcode() == HloOpcode::kConditional) { TF_RETURN_IF_ERROR( AddCopiesForConditional(*alias_analysis, instruction)); + } else { + for (const auto& operand_and_output_index : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) { + const HloUse& operand = operand_and_output_index.first; + CHECK_EQ(operand.operand_index, ShapeIndex{}) + << "Support for non-{} shape operand not currently implemented."; + TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation( + *alias_analysis, instruction, operand.operand_number)); + } } } } diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 3ee6b200da5..78730cbdcb8 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -2530,5 +2530,250 @@ ENTRY Entry { EXPECT_EQ(CountCopies(*module), 1); } +TEST_F(CopyInsertionTest, DynamicUpdateSliceNoCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 0); +} + +TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceNoCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + ROOT fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 0); +} + +TEST_F(CopyInsertionTest, 
DynamicUpdateSliceCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + add = f32[1280,1,128] add(negate, negate) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(add, dynamic-update-slice.5) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, DynamicUpdateSliceParameterShareCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +ENTRY main { + param = f32[1280,1,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3, constant.3) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + add = f32[1280,1,128] add(negate, negate) + fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation + ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(negate, fusion) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, ChainDynamicUpdateSliceCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +ENTRY main { + state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={} + get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1 + get-tuple-element.3 = s32[] get-tuple-element(state), index=0 + constant.2 = s32[] constant(128) + add.5 = s32[] add(get-tuple-element.3, constant.2) + constant.3 = s32[] constant(0) + dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy2) { + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation.1 { + param0 = f32[1280,1,128] 
parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) +} + +fused_computation.2 { + param0 = f32[1280,1,128] parameter(0) + param1 = f32[1280,1,128] parameter(1) + slice = f32[128,1,128] slice(param1), slice={[0:128], [0:1], [0:128]} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, slice, constant.3, constant.3, constant.3) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + add = f32[1280,1,128] add(negate, negate) + fusion1 = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation.1 + ROOT fusion2 = f32[1280,1,128] fusion(fusion1, negate), kind=kLoop, calls=fused_computation.2 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceCopy) { + // Tests multi-output fusion with two DUS outputs, requiring two copies. + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + param1 = f32[1280,1,128] parameter(1) + param2 = f32[1280,1,128] parameter(2) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + add.1 = f32[1280,1,128] add(param0, param0) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate0 = f32[1280,1,128] negate(param) + negate1 = f32[1280,1,128] negate(param) + negate2 = f32[1280,1,128] negate(param) + fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation + gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0 + gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1 + gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2 + add0 = f32[1280,1,128] add(negate0, gte0) + add1 = f32[1280,1,128] add(negate1, gte1) + add2 = f32[1280,1,128] add(negate2, gte2) + ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 2); +} + +TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceNoCopy) { + // Same as above, but negate1 is not used beyond fusion, so it only needs one + // copy for negate0. 
+ absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + param1 = f32[1280,1,128] parameter(1) + param2 = f32[1280,1,128] parameter(2) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + add.1 = f32[1280,1,128] add(param0, param0) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate0 = f32[1280,1,128] negate(param) + negate1 = f32[1280,1,128] negate(param) + negate2 = f32[1280,1,128] negate(param) + fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation + gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0 + gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1 + gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2 + add0 = f32[1280,1,128] add(negate0, gte0) + add1 = f32[1280,1,128] add(gte1, gte1) + add2 = f32[1280,1,128] add(negate2, gte2) + ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 6eaf43902fe..4e25d667d03 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -130,21 +130,24 @@ cc_library( ":target_machine_features", "@com_google_absl//absl/base", "@com_google_absl//absl/types:span", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", - "@llvm-project//mlir:ExecutionEngineUtils", "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:VectorOps", "//tensorflow/compiler/xla/service:copy_insertion", - "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:topk_rewriter", "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:rng_bit_generator_expander", "//tensorflow/compiler/xla/service:tree_reduction_rewriter", - "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:conditional_canonicalizer", "//tensorflow/compiler/xla/service:conditional_to_select", "//tensorflow/compiler/xla/service:slow_operation_alarm", "//tensorflow/compiler/xla/service:scatter_expander", + "//tensorflow/compiler/xla/service:comparison_expander", "//tensorflow/compiler/xla/service:slice_sinker", "//tensorflow/compiler/xla:cpu_function_runtime", "//tensorflow/compiler/xla:literal", @@ -183,6 +186,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:indexed_array_analysis", "//tensorflow/compiler/xla/service:llvm_compiler", + "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:reshape_mover", 
"//tensorflow/compiler/xla/service:rng_expander", "//tensorflow/compiler/xla/service:sort_simplifier", @@ -197,7 +201,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@llvm-project//llvm:Core", - "@llvm-project//llvm:MC", "@llvm-project//llvm:Object", "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 0826d7b8ce1..e6c72e60636 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -42,7 +42,12 @@ limitations under the License. #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project #include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/literal.h" @@ -54,6 +59,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/cholesky_expander.h" +#include "tensorflow/compiler/xla/service/comparison_expander.h" #include "tensorflow/compiler/xla/service/conditional_canonicalizer.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/conditional_to_select.h" @@ -77,13 +83,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/dynamic_padder.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" -#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -120,6 +126,21 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/dynamic_annotations.h" +namespace { + +// We need to explicitly load all the dialects we will involved in emitting the +// IR. This is only needed because of how MLIR is bolted into XLA and does not +// make use of the MLIR infrastructure (like using a proper pass pipeline). +// Hopefully this will all go away at some point in favor of a better +// integration. 
+void LoadMLIRDialects(mlir::MLIRContext& context) { + context.loadDialect(); +} + +} // namespace + namespace xla { namespace cpu { using BufferInfo = cpu_function_runtime::BufferInfo; @@ -163,8 +184,6 @@ CpuCompiler::CpuCompiler() { // Initialize LLVM's MC layer for the native target. llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); - - mlir::registerAllDialects(); } namespace { @@ -260,6 +279,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); @@ -288,8 +308,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*expansion_type=*/LogisticExpansionType::kExp); pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(ScatterExpander::kEliminateAllScatters); pipeline.AddPass(target_machine_features); { auto& pass = @@ -303,6 +322,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pass.AddPass(options); pass.AddPass(); pass.AddPass(); + pass.AddPass(GatherExpander::kEliminateSimpleGathers); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. @@ -620,10 +640,10 @@ StatusOr> CpuCompiler::RunBackend( // Compile must be thread-safe so create a new LLVM context for the module. mlir::MLIRContext mlir_context; - auto llvm_module = absl::make_unique( - "__compute_module", - mlir_context.getRegisteredDialect() - ->getLLVMContext()); + LoadMLIRDialects(mlir_context); + llvm::LLVMContext llvm_context; + auto llvm_module = + absl::make_unique("__compute_module", llvm_context); auto jit = absl::make_unique( CompilerTargetOptions(module->config()), @@ -832,10 +852,9 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, // Compile must be thread-safe so create a new LLVM context for the module. 
mlir::MLIRContext mlir_context; - llvm::Module llvm_module( - "__compute_module", - mlir_context.getRegisteredDialect() - ->getLLVMContext()); + LoadMLIRDialects(mlir_context); + llvm::LLVMContext llvm_context; + llvm::Module llvm_module("__compute_module", llvm_context); llvm_module.setDataLayout(target_machine->createDataLayout()); llvm_module.setTargetTriple(triple.getTriple()); if (pic_level != llvm::PICLevel::NotPIC) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 0abcc91a1d7..7431e829b8e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -247,6 +247,12 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( ExecutionInput& input = arguments[alias->parameter_number]; MaybeOwningDeviceMemory* maybe_owning_memory = input.MutableBuffer(alias->parameter_index); + if (alias->must_alias() && !maybe_owning_memory->HasOwnership()) { + return InvalidArgument( + "An input was configured to be must-alias at " + "compile time but not donated at runtime: %s", + alias->ToString()); + } if (absl::optional owning = maybe_owning_memory->Release()) { // If the caller passes the ownership of the device memory, reuse it diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 9460cc55e10..42c6c9839bf 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -95,7 +95,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) && - consumer->ReusesOperandElements(operand_index)) { + ReusesOperandElements(consumer, operand_index)) { VLOG(2) << "Fusion is not profitable."; return false; } @@ -132,7 +132,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, fusion_node_evaluations_.emplace(consumer, FusionNodeIndexingEvaluation(consumer)); } - if (fusion_node_evaluations_.at(consumer).AverageCodeDuplicationTooHigh( + if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh( producer)) { return false; } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 72f4d5369c8..36566d6c25f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1640,7 +1640,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( if (current_size_fragment >= vector_register_size_in_elements) { auto vector_type = llvm::VectorType::get( - element_ir_type, vector_register_size_in_elements); + element_ir_type, vector_register_size_in_elements, false); sharded_vector_type.insert( sharded_vector_type.end(), current_size_fragment / vector_register_size_in_elements, @@ -1656,7 +1656,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( // of two are all legal vector sizes (or at least can be lowered easily by // LLVM). 
sharded_vector_type.push_back( - llvm::VectorType::get(element_ir_type, current_size_fragment)); + llvm::VectorType::get(element_ir_type, current_size_fragment, false)); } return sharded_vector_type; } @@ -2412,11 +2412,14 @@ Status IrEmitter::HandleTopK(HloInstruction* hlo) { EmitBufferPointer(out_values_slice, hlo->shape().tuple_shapes(0)); llvm::Value* out_indices_ptr = EmitBufferPointer(out_indices_slice, hlo->shape().tuple_shapes(1)); - EmitCallToFunc(runtime::kTopKF32SymbolName, - {b_.getInt64(has_batch ? input->shape().dimensions(0) : 1), - b_.getInt64(input->shape().dimensions().back()), - b_.getInt64(k), values_ptr, out_values_ptr, out_indices_ptr}, - b_.getVoidTy()); + EmitCallToFunc( + runtime::kTopKF32SymbolName, + {b_.getInt64(has_batch ? input->shape().dimensions(0) : 1), + b_.getInt64(input->shape().dimensions().back()), b_.getInt64(k), + BitCast(values_ptr, b_.getFloatTy()->getPointerTo()), + BitCast(out_values_ptr, b_.getFloatTy()->getPointerTo()), + BitCast(out_indices_ptr, b_.getInt32Ty()->getPointerTo())}, + b_.getVoidTy()); llvm_ir::EmitTuple(GetIrArrayFor(hlo), {out_values_ptr, out_indices_ptr}, &b_); diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index 8d9229c1223..3afdd9c163e 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -115,7 +115,7 @@ void RewriteCalls( // Upcast to vector type if input is a scalar. if (vector_width == 1) { - llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1); + llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1, false); input = b.CreateInsertElement(llvm::UndefValue::get(v1_type), input, uint64_t{0}); } @@ -264,8 +264,8 @@ llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input, z = vsl.Add(one, z); // Convert n' to an i32. This is safe because we clamped it above. - llvm::Value* n_i32 = - b->CreateFPToSI(n, llvm::VectorType::get(b->getInt32Ty(), vector_width)); + llvm::Value* n_i32 = b->CreateFPToSI( + n, llvm::VectorType::get(b->getInt32Ty(), vector_width, false)); auto splat_i32 = [&](int32 v) { return b->CreateVectorSplat(vector_width, b->getInt32(v)); @@ -329,7 +329,7 @@ llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input, llvm::Value* vector_constant_23 = b->CreateVectorSplat(vector_width, b->getInt32(23)); llvm::Type* i32_vector_type = - llvm::VectorType::get(b->getInt32Ty(), vector_width); + llvm::VectorType::get(b->getInt32Ty(), vector_width, false); llvm::Value* emm0 = b->CreateLShr(b->CreateBitCast(tmp0, i32_vector_type), vector_constant_23); diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc index ff48f554ce6..ae23f224207 100644 --- a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc @@ -32,7 +32,8 @@ namespace cpu { namespace { // Lower an MLIR module to an LLVM module. -std::unique_ptr MakeLLVMModule(mlir::OwningModuleRef module) { +std::unique_ptr MakeLLVMModule(mlir::OwningModuleRef module, + llvm::LLVMContext *context) { // When set, the LLVM backend will be allowed to reassociate floating-point // reductions, which enables much more efficient "horizontal" SIMD // implementations. 
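Illustrative sketch (editorial, not part of the patch): the comment above, and the kReassociateFPReductions option used in the hunk that follows, exist because floating-point addition is not associative, so the LLVM backend must be told explicitly that reordering a reduction (for example into a vectorized "horizontal" sum) is acceptable. A tiny standalone illustration of the non-associativity:

#include <cstdio>

int main() {
  const float a = 1e8f, b = -1e8f, c = 1.0f;
  // Left-to-right evaluation keeps the 1.0f; reassociating loses it, because
  // 1.0f is below the representable spacing of float values near 1e8.
  std::printf("(a + b) + c = %g\n", (a + b) + c);  // prints 1
  std::printf("a + (b + c) = %g\n", a + (b + c));  // prints 0
  return 0;
}

Allowing reassociation trades exactly this kind of bit-for-bit reproducibility for faster SIMD reductions.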
@@ -47,7 +48,7 @@ std::unique_ptr MakeLLVMModule(mlir::OwningModuleRef module) { mlir::LowerVectorToLLVMOptions().setReassociateFPReductions( kReassociateFPReductions))); CHECK(succeeded(manager.run(*module))); - return mlir::translateModuleToLLVMIR(*module); + return mlir::translateModuleToLLVMIR(*module, *context); } // Get arguments to pass a memref to an mlir function. @@ -114,7 +115,8 @@ Status EmitMlirFuncAndCall( emitter(&op_builder, function); // Now link it all into the main LLVM module. - auto mlir_llvm_module = MakeLLVMModule(std::move(mlir_module)); + auto mlir_llvm_module = + MakeLLVMModule(std::move(mlir_module), &b->getContext()); mlir_llvm_module->setDataLayout(llvm_module->getDataLayout()); llvm::Linker::linkModules( *llvm_module, std::move(mlir_llvm_module), llvm::Linker::None, diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 225102e6ae6..48f2248d2d7 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -143,7 +143,8 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( // TODO(b/27458679) Parallelize instructions which are skipped here. auto opcode = instruction->opcode(); if (llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) || - instruction->shape().IsTuple() || opcode == HloOpcode::kRng) { + instruction->shape().IsTuple() || opcode == HloOpcode::kRng || + opcode == HloOpcode::kConstant) { return 1; } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index e22210a61f2..5b454379876 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -191,5 +191,19 @@ TEST_F(ParallelTaskAssignmentTest, AllReduceNotParallelized) { EXPECT_FALSE(changed); } +TEST_F(ParallelTaskAssignmentTest, ConstantNotParallelized) { + constexpr char hlo_string[] = R"( + HloModule TestTaskParallel_constant + ENTRY const { + ROOT constant = f32[1234567] constant({...}) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); + EXPECT_FALSE(changed); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 0d2eab9fd42..48aa32f6b8f 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -33,7 +33,7 @@ VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, scalar_type_ = llvm_ir::PrimitiveTypeToIrType( primitive_type, b_->GetInsertBlock()->getModule()); scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_); - vector_type_ = llvm::VectorType::get(scalar_type_, vector_size); + vector_type_ = llvm::VectorType::get(scalar_type_, vector_size, false); vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_); } @@ -155,7 +155,7 @@ llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) { int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type()); llvm::Type* scalar_int_type = b()->getIntNTy(float_size_bits); if (vector) { - return llvm::VectorType::get(scalar_int_type, vector_size()); + return 
llvm::VectorType::get(scalar_int_type, vector_size(), false); } else { return scalar_int_type; } diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index f1a0b0a4406..cbed232897f 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -276,7 +276,7 @@ class VectorSupportLibrary { llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f); if (llvm::isa(type)) { return llvm::ConstantVector::getSplat( - llvm::ElementCount(vector_size(), /*Scalable=*/false), scalar_value); + llvm::ElementCount::getFixed(vector_size()), scalar_value); } return scalar_value; } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index b0def1a2dd8..60d832a940a 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -245,6 +245,7 @@ class DfsHloVisitorBase { virtual Status HandleBitcast(HloInstructionPtr hlo) = 0; virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0; virtual Status HandleReshape(HloInstructionPtr hlo) = 0; + virtual Status HandleDynamicReshape(HloInstructionPtr hlo) = 0; virtual Status HandleTranspose(HloInstructionPtr hlo) = 0; virtual Status HandleParameter(HloInstructionPtr hlo) = 0; virtual Status HandleFusion(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index b1d674fe467..3d1a9a3c894 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -198,6 +198,9 @@ class DfsHloVisitorWithDefaultBase Status HandlePad(HloInstructionPtr pad) override { return DefaultAction(pad); } + Status HandleDynamicReshape(HloInstructionPtr dynamic_reshape) override { + return DefaultAction(dynamic_reshape); + } Status HandleReshape(HloInstructionPtr reshape) override { return DefaultAction(reshape); } diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc index 4670ce6940a..3adde5f7d48 100644 --- a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc +++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc @@ -49,14 +49,11 @@ bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size) { return false; } -/* static */ absl::optional -ParseDotGeneralFromConvolution(const HloInstruction* conv) { +/* static */ DotConvolutionDimsInfo ParseConvolutionDimsInfo( + const HloInstruction* conv) { CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); - if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) { - return absl::nullopt; - } const auto& conv_dims = conv->convolution_dimension_numbers(); - DotGeneralAsConvolutionDimsInfo dims; + DotConvolutionDimsInfo dims; dims.lhs_non_contracting_dims.push_back( {conv_dims.input_batch_dimension(), -1, conv_dims.output_batch_dimension(), -1}); @@ -98,10 +95,10 @@ ParseDotGeneralFromConvolution(const HloInstruction* conv) { // padding N - 1, high padding N - 1 and window reversal. 
dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i}); } else { - return absl::nullopt; + dims.conv_spatial_dims.push_back({lhs, rhs, output, i}); } } else { - return absl::nullopt; + dims.conv_spatial_dims.push_back({lhs, rhs, output, i}); } } @@ -110,8 +107,7 @@ ParseDotGeneralFromConvolution(const HloInstruction* conv) { StatusOr> CreateShardedConvForDotGeneralConvolution( - const HloInstruction& conv, - const DotGeneralAsConvolutionDimsInfo& dot_dnums, + const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums, HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) { CHECK_EQ(conv.opcode(), HloOpcode::kConvolution); const auto& conv_dnums = conv.convolution_dimension_numbers(); @@ -141,16 +137,66 @@ CreateShardedConvForDotGeneralConvolution( wd->set_padding_high(wd->size() - 1); wd->set_padding_low(wd->size() - 1); } - TF_ASSIGN_OR_RETURN(Shape sharded_conv_shape, - ShapeInference::InferConvolveShape( - sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(), - /*feature_group_count=*/1, - /*batch_group_count=*/1, window, conv_dnums)); + TF_ASSIGN_OR_RETURN( + Shape sharded_conv_shape, + ShapeInference::InferConvolveShape( + sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(), + /*feature_group_count=*/conv.feature_group_count(), + /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums)); *sharded_conv_shape.mutable_layout() = conv.shape().layout(); return HloInstruction::CreateConvolve( sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, - /*feature_group_count=*/1, - /*batch_group_count=*/1, window, conv_dnums, conv.precision_config()); + /*feature_group_count=*/conv.feature_group_count(), + /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums, + conv.precision_config()); +} + +DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot) { + const auto& dot_dim_numbs = dot->dot_dimension_numbers(); + dot_as_convolution_util::DotConvolutionDimsInfo dnums; + for (int64 i = 0; i < dot_dim_numbs.lhs_batch_dimensions().size(); ++i) { + dnums.batch_dims.emplace_back(); + dnums.batch_dims.back().lhs = dot_dim_numbs.lhs_batch_dimensions(i); + dnums.batch_dims.back().rhs = dot_dim_numbs.rhs_batch_dimensions(i); + dnums.batch_dims.back().output = i; + dnums.batch_dims.back().spatial_dim = -1; + } + for (int64 i = 0; i < dot_dim_numbs.lhs_contracting_dimensions().size(); + ++i) { + dnums.contracting_dims.emplace_back(); + dnums.contracting_dims.back().lhs = + dot_dim_numbs.lhs_contracting_dimensions(i); + dnums.contracting_dims.back().rhs = + dot_dim_numbs.rhs_contracting_dimensions(i); + dnums.contracting_dims.back().output = -1; + dnums.contracting_dims.back().spatial_dim = -1; + } + for (int64 i = 0; i < dot->operand(0)->shape().rank(); ++i) { + if (!absl::c_linear_search(dot_dim_numbs.lhs_batch_dimensions(), i) && + !absl::c_linear_search(dot_dim_numbs.lhs_contracting_dimensions(), i)) { + dnums.lhs_non_contracting_dims.emplace_back(); + dnums.lhs_non_contracting_dims.back().lhs = i; + dnums.lhs_non_contracting_dims.back().rhs = -1; + dnums.lhs_non_contracting_dims.back().output = + dot_dim_numbs.lhs_batch_dimensions_size() + + dnums.lhs_non_contracting_dims.size() - 1; + dnums.lhs_non_contracting_dims.back().spatial_dim = -1; + } + } + for (int64 i = 0; i < dot->operand(1)->shape().rank(); ++i) { + if (!absl::c_linear_search(dot_dim_numbs.rhs_batch_dimensions(), i) && + !absl::c_linear_search(dot_dim_numbs.rhs_contracting_dimensions(), i)) { + dnums.rhs_non_contracting_dims.emplace_back(); + 
dnums.rhs_non_contracting_dims.back().lhs = -1; + dnums.rhs_non_contracting_dims.back().rhs = i; + dnums.rhs_non_contracting_dims.back().output = + dot_dim_numbs.lhs_batch_dimensions_size() + + dnums.lhs_non_contracting_dims.size() + + dnums.rhs_non_contracting_dims.size() - 1; + dnums.rhs_non_contracting_dims.back().spatial_dim = -1; + } + } + return dnums; } } // namespace dot_as_convolution_util diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.h b/tensorflow/compiler/xla/service/dot_as_convolution_util.h index 6a7cacf812d..16a542208d2 100644 --- a/tensorflow/compiler/xla/service/dot_as_convolution_util.h +++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.h @@ -25,8 +25,9 @@ limitations under the License. namespace xla { namespace dot_as_convolution_util { -// Describes the dimensions of a convolution that can be interpreted as a dot. -struct DotGeneralAsConvolutionDimsInfo { +// Describes the dimensions of a convolution that can be interpreted as a dot +// or a normal convolution. +struct DotConvolutionDimsInfo { // The dimension numbers for the operands and output corresponding to a // logical dimension (e.g., batch, contracting, non-contracting). If an // operand or the output doesn't have the logical dimension, it is set to @@ -43,23 +44,22 @@ struct DotGeneralAsConvolutionDimsInfo { std::vector contracting_dims; std::vector lhs_non_contracting_dims; std::vector rhs_non_contracting_dims; + std::vector conv_spatial_dims; }; -// Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo if it can -// be interpreted as a dot, or absl::nullopt otherwise. -absl::optional ParseDotGeneralFromConvolution( - const HloInstruction* conv); +// Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo. If it can +// be interpreted as a dot, there is no conv_spatial_dims. +DotConvolutionDimsInfo ParseConvolutionDimsInfo(const HloInstruction* conv); // Creates sharded convolution instruction that can be interpreted as a dot. // This is a utility for per-op partitioners. // - 'conv' is the original convolution instruction. -// - 'dot_dnums' is the result of ParseDotGeneralFromConvolution() for 'conv'. +// - 'dot_dnums' is the result of ParseDotConvolutionDimsInfo() for 'conv'. // - 'sharded_lhs_hlo' and 'sharded_rhs_hlo' are sharded inputs for the result // convolution instruction. StatusOr> CreateShardedConvForDotGeneralConvolution( - const HloInstruction& conv, - const DotGeneralAsConvolutionDimsInfo& dot_dnums, + const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums, HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo); // Check if a spatial dim is parallel batch dimension. @@ -68,6 +68,10 @@ CreateShardedConvForDotGeneralConvolution( // dilation B. bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size); +// Returns a DotConvolutionDimsInfo from a kDot instruction, where all +// the spatial_dim values are set to -1. 
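A concrete instance of the mapping that the declaration just below implements; the dot shapes here are hypothetical, chosen only for illustration:

// For dot = f32[8,128,256] dot(f32[8,128,64] lhs, f32[8,64,256] rhs) with
//   lhs_batch_dims={0}, rhs_batch_dims={0},
//   lhs_contracting_dims={2}, rhs_contracting_dims={1},
// ParseDotGeneralFromDot returns a DotConvolutionDimsInfo with
//   batch_dims               = {{lhs:0,  rhs:0,  output:0,  spatial_dim:-1}}
//   contracting_dims         = {{lhs:2,  rhs:1,  output:-1, spatial_dim:-1}}
//   lhs_non_contracting_dims = {{lhs:1,  rhs:-1, output:1,  spatial_dim:-1}}
//   rhs_non_contracting_dims = {{lhs:-1, rhs:2,  output:2,  spatial_dim:-1}}
//   conv_spatial_dims        = {}   (a plain dot has no spatial dimensions)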
+DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot); + } // namespace dot_as_convolution_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 2f2456863e9..80f98775c01 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -97,6 +97,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { Status HandleTranspose(HloInstruction* hlo) override; + Status HandleDynamicReshape(HloInstruction* hlo) override; + Status HandleReshape(HloInstruction* hlo) override; Status HandleSort(HloInstruction* hlo) override; @@ -621,6 +623,18 @@ Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) { return PassThroughDynamicDimension(hlo); } +Status DynamicDimensionInferenceVisitor::HandleDynamicReshape( + HloInstruction* hlo) { + HloDynamicReshapeInstruction* dynamic_reshape = + Cast(hlo); + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (hlo->shape().is_dynamic_dimension(i)) { + parent_->SetDynamicSize(hlo, {}, i, dynamic_reshape->dim_sizes(i)); + } + } + return Status::OK(); +} + Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, @@ -805,7 +819,8 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { } if (input_dim_size > output_dim_size) { - TF_RET_CHECK(input_dim_size % output_dim_size == 0); + TF_RET_CHECK(input_dim_size % output_dim_size == 0) + << reshape->ToString(); const int64 divisor = input_dim_size / output_dim_size; HloInstruction* divisor_hlo = hlo->parent()->AddInstruction(HloInstruction::CreateConstant( diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index b5a17619edf..69f64c31a2f 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -1248,5 +1248,34 @@ TEST_F(DynamicDimensionInferenceTest, InfersCustomOp) { EXPECT_TRUE(handler_called); } +TEST_F(DynamicDimensionInferenceTest, DynamicReshapeOp) { + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {9}), "data_input")); + auto six = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(6))); + // Creates an input of shape [<=9], dynamic size is 6. 
+ auto dynamic_input = + builder.AddInstruction(HloInstruction::CreateSetDimensionSize( + ShapeUtil::MakeShape(F32, {9}, {true}), input, six, 0)); + auto dynamic_size = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(S32, {}), "size_param")); + auto three = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3))); + + // Reshape [<=9] into [3, <=3] + + auto dynamic_reshape = + builder.AddInstruction(HloInstruction::CreateDynamicReshape( + ShapeUtil::MakeShape(F32, {3, 3}, {false, true}), dynamic_input, + {three, dynamic_size})); + + module_->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 1), dynamic_size); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index c1f9da599e8..b4c56113239 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -32,6 +32,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -125,6 +127,74 @@ StatusOr ChooseIdentityValue(HloInstruction* inst, } } +StatusOr ReplaceGetSize( + HloInstruction* instr, + DynamicDimensionInference* dynamic_dimension_inference) { + if (instr->opcode() != HloOpcode::kGetDimensionSize) { + return false; + } + HloComputation* computation = instr->parent(); + + TF_ASSIGN_OR_RETURN(auto legal_shape, + ShapeInference::InferGetDimensionSizeShape( + instr->operand(0)->shape(), instr->dimension())); + TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)) + << "instr->shape() " << instr->shape().ToString() << " , " + << "legal_shape " << legal_shape.ToString(); + TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32)); + HloInstruction* operand = instr->mutable_operand(0); + int64 dim = instr->dimension(); + HloInstruction* dynamic_size = + dynamic_dimension_inference->GetDynamicSize(operand, {}, dim); + if (dynamic_size != nullptr) { + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); + // The dependency between a instruction and its dynamic dimensions is not + // modeled in the IR. As instr is being replaced by dynamic_size, also tell + // dynamic dimension inference that the instruction is being replaced. 
+ dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith( + instr, dynamic_size); + } else { + int32 size = instr->operand(0)->shape().dimensions(dim); + HloInstruction* new_instr = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr, + new_instr); + } + return true; +} + +StatusOr ReplaceSetSize(HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kSetDimensionSize) { + return false; + } + + TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()( + instr->shape(), instr->operand(0)->shape())) + << "instr->shape() " << instr->shape().ToString() << " , " + << "instruction operand shape " << instr->operand(0)->shape(); + HloInstruction* operand = instr->mutable_operand(0); + + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand)); + return true; +} + +StatusOr ReplaceSetBound(HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kCustomCall || + instr->custom_call_target() != "SetBound") { + return false; + } + + TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()( + instr->shape(), instr->operand(0)->shape())) + << "instr->shape() " << instr->shape().ToString() << " , " + << "instruction operand shape " << instr->operand(0)->shape(); + HloInstruction* operand = instr->mutable_operand(0); + + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand)); + return true; +} + bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num, int64 dimension) { if ((inst->opcode() == HloOpcode::kReduceWindow || @@ -1236,6 +1306,18 @@ StatusOr DynamicPadder::Run(HloModule* module) { changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference)); continue; } + + if (inst->opcode() == HloOpcode::kDynamicReshape) { + TF_ASSIGN_OR_RETURN( + changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference)); + auto* static_reshape = + computation->AddInstruction(HloInstruction::CreateReshape( + inst->shape(), inst->mutable_operand(0))); + TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(static_reshape)); + TF_RETURN_IF_ERROR(dynamic_dimension_inference.ForwardDynamicSize( + inst, static_reshape, {})); + continue; + } for (int64 operand_num = 0; operand_num < inst->operand_count(); ++operand_num) { HloInstruction* original_operand = inst->mutable_operand(operand_num); @@ -1292,6 +1374,25 @@ StatusOr DynamicPadder::Run(HloModule* module) { /*require_dynamic_output=*/require_dynamic_output)); } + for (auto* computation : module->computations()) { + for (auto instruction : computation->MakeInstructionPostOrder()) { + TF_ASSIGN_OR_RETURN( + bool replaced_get_size, + ReplaceGetSize(instruction, &dynamic_dimension_inference)); + changed = changed || replaced_get_size; + } + } + + for (auto* computation : module->computations()) { + for (auto instruction : computation->MakeInstructionPostOrder()) { + TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction)); + TF_ASSIGN_OR_RETURN(bool replaced_set_bound, + ReplaceSetBound(instruction)); + changed = changed || replaced_set_size; + changed = changed || replaced_set_bound; + } + } + HloDCE dce; TF_ASSIGN_OR_RETURN(changed, dce.Run(module)); VLOG(2) << "Post DynamicPadder HLO:"; diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc index e8f429d9db6..3855531a97b 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc +++ 
b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" -#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -380,10 +379,15 @@ class ExecutionTest : public HloTestBase { Literal PadAndExecute(std::unique_ptr module, absl::Span arguments, bool slice_dynamic_output = true) { + if (!slice_dynamic_output) { + auto new_config = module->config(); + new_config.mutable_entry_computation_layout() + ->mutable_result_layout() + ->ClearDynamicShape(); + module->set_config(new_config); + } DynamicPadder padder(slice_dynamic_output); TF_CHECK_OK(padder.Run(module.get()).status()); - HloGetDimensionSizeRewriter rewriter; - TF_CHECK_OK(rewriter.Run(module.get()).status()); HloDCE dce; TF_CHECK_OK(dce.Run(module.get()).status()); return ExecuteAndTransfer(std::move(module), arguments); @@ -1179,6 +1183,84 @@ ENTRY main { EXPECT_EQ(result, expected); } +XLA_TEST_F(ExecutionTest, DynamicReshapeDoubleDynamicDimensions) { + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +ENTRY main { + param = s32[2, 3, 3] parameter(0) + size = s32[] constant(2) + param_padded_partial = s32[2, <=3, 3] set-dimension-size(param, size), + dimensions={1} + param_padded = s32[2, <=3, <=3] set-dimension-size(param_padded_partial, size), + dimensions={2} + result_size = s32[] constant(8) + ROOT reshaped = s32[<=18] dynamic-reshape(param_padded, result_size) +} +)"; + + // First dimension (1) is dynamic. Since dynamic size is 0, result is also 0. + Literal operand = LiteralUtil::CreateR3( + {{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}}); + auto module = GetHloModule(hlo_text); + + Literal result = PadAndExecute(std::move(module), {&operand}, false); + result.SetDynamicSize(0, 8); + // Padded data looks like this (P is padding which is ignored). + // [[0, 1, P] + // [3, 4, P] + // [P, P, P]] + // + // [[0, 1, P] + // [3, 4, P] + // [P, P, P]] + // + // Reshaping (with correct reshape rewriting) produces: + // [0, 1, 3, 4, 0, 1, 3, 4] + Literal expected = LiteralUtil::CreateR1({0, 1, 3, 4, 0, 1, 3, 4}); + + EXPECT_EQ(result, expected); +} + +XLA_TEST_F(ExecutionTest, DynamicReshapeOutputDoubleDynamicDimensions) { + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +ENTRY main { + param = s32[18] parameter(0) + eight = s32[] constant(8) + param_dynamic = s32[<=18] set-dimension-size(param, eight), dimensions={0} + two = s32[] constant(2) + // every dimension has dynamic size two. + ROOT reshaped = s32[2, <=3, <=3] dynamic-reshape(param_dynamic, two, two, two) +} +)"; + Literal operand = LiteralUtil::CreateR1( + {0, 1, 3, 4, 0, 1, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}); + + auto module = GetHloModule(hlo_text); + + Literal result = PadAndExecute(std::move(module), {&operand}, false); + + result.SetDynamicSize(1, 2); + result.SetDynamicSize(2, 2); + // Padded operand is: + // [0, 1, 3, 4, 0, 1, 3, 4, P, P ....] 
+ // + // Reshaping it should produce: + // [[0, 1, P] + // [3, 4, P] + // [P, P, P]] + // + // [[0, 1, P] + // [3, 4, P] + // [P, P, P]] + Literal expected = + LiteralUtil::CreateR3({{{0, 1}, {3, 4}}, {{0, 1}, {3, 4}}}); + + EXPECT_EQ(result, expected); +} + XLA_TEST_F(ExecutionTest, SetGetDimensionSize) { const string hlo_text = R"( HloModule TensorFlowScatterV1 @@ -1371,5 +1453,70 @@ ENTRY main { EXPECT_EQ(result, expected); } +namespace op = xla::testing::opcode_matchers; + +class HloDimensionSizeLegalizerTest : public HloTestBase { + protected: + HloDimensionSizeLegalizerTest() {} +}; + +TEST_F(HloDimensionSizeLegalizerTest, Ok) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule _ +ENTRY gds { + p = s32[3,4] parameter(0) + size0 = s32[] get-dimension-size(p), dimensions={0} + size1 = s32[] get-dimension-size(p), dimensions={1} + ROOT mul = s32[] multiply(size0, size1) +})") + .ValueOrDie(); + DynamicPadder pass; + EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Multiply(op::Constant(), op::Constant())); +} + +TEST_F(HloDimensionSizeLegalizerTest, GetSetSetDimensionSizeRewriter) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule _ +ENTRY gds { + p = s32[3,4] parameter(0) + size0 = s32[] get-dimension-size(p), dimensions={0} + p_copy = s32[3,4] copy(p) + p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0} + size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0} + ROOT mul = s32[] multiply(size0, size1) +})") + .ValueOrDie(); + DynamicPadder pass; + EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Multiply(op::Constant(), op::Constant())); +} + +TEST_F(HloDimensionSizeLegalizerTest, IllegalType) { + auto module = ParseAndReturnUnverifiedModule(R"( +HloModule _ +ENTRY gds { + p = s32[3]{0} parameter(0) + ROOT gds = s64[] get-dimension-size(p), dimensions={0} +})") + .ValueOrDie(); + DynamicPadder pass; + EXPECT_FALSE(pass.Run(module.get()).ok()); +} + +TEST_F(HloDimensionSizeLegalizerTest, IllegalDimension) { + auto module = ParseAndReturnUnverifiedModule(R"( +HloModule _ +ENTRY gds { + p = f32[2,5] parameter(0) + ROOT gds = s32[] get-dimension-size(p), dimensions={2} +})") + .ValueOrDie(); + DynamicPadder pass; + EXPECT_FALSE(pass.Run(module.get()).ok()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc index 75d39298aa3..ab6a3d01d21 100644 --- a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc +++ b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc @@ -27,33 +27,25 @@ namespace xla { FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation( const HloInstruction* fusion) : fusion_(fusion) { - total_emitted_instructions_ = 0; HloInstruction* root = fusion->fused_expression_root(); indexing_users_[root].insert(fusion); index_usage_count_[fusion] = 1; RecomputeCache(); } -bool FusionNodeIndexingEvaluation::AverageCodeDuplicationTooHigh( +bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh( const HloInstruction* producer) const { // This constant is arbitrarily chosen. Essentially we don't want to have too // much code duplication, because it slows down the compilation time. There is // a tradeoff between compilation time and runtime here. 
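A small worked example of the revised duplication check; the usage counts are hypothetical:

// Suppose 'producer' has two indexing users inside the fusion, u1 and u2,
// whose index_usage_count_ entries are 3 and 2. Then
//   EvaluateEmittedInstructions(producer) = 3 + 2 = 5,
// i.e. fusing 'producer' would emit it five times. Since 5 <= 15
// (kAllowedCodeDuplication, declared just below), CodeDuplicationTooHigh
// returns false and the duplication check does not block the fusion.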
const int64 kAllowedCodeDuplication = 15; - // index_usage_count_ contains an entry for each instruction in the fusion - // computation (except parameter instructions), plus an entry for the 'fusion' - // instruction. So the size of this map is already one bigger than the number - // of instructions in the fusion node that are emitted, thus accounting for - // the number of instructions after 'producer' is fused. - return EvaluateTotalEmittedInstructions(producer) / - index_usage_count_.size() > - kAllowedCodeDuplication; + return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication; } -int64 FusionNodeIndexingEvaluation::EvaluateTotalEmittedInstructions( +int64 FusionNodeIndexingEvaluation::EvaluateEmittedInstructions( const HloInstruction* producer) const { - int64 total = total_emitted_instructions_; + int64 total = 0; for (const auto* user : indexing_users_.at(producer)) { total += index_usage_count_.at(user); } @@ -96,19 +88,9 @@ void FusionNodeIndexingEvaluation::UpdateIndexUsageCount( const HloInstruction* instruction) { int64 total = 0; for (const auto* user : indexing_users_[instruction]) { - int64 weight = 1; - // Concatenate is special: the index differs for each operand, so - // in the worst case we have to deal with as many index values as - // the number of operands of Concatenate. By considering the worst - // case, we are more conservative than necessary regarding - // counting the index usage. - if (user->opcode() == HloOpcode::kConcatenate) { - weight = user->operand_count(); - } - total += index_usage_count_.at(user) * weight; + total += index_usage_count_.at(user); } CHECK(index_usage_count_.emplace(instruction, total).second); - total_emitted_instructions_ += total; } void FusionNodeIndexingEvaluation::UpdateIndexingUsersOfOperands( diff --git a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h index 9630986d188..b85bf9104c7 100644 --- a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h +++ b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h @@ -26,17 +26,14 @@ class FusionNodeIndexingEvaluation { public: explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion); - // Evaluate the average number of times an instruction is emitted inside the - // fusion node, if 'producer' is fused into 'fusion_'. If this average - // duplication is "too high" (some arbitrary chosen constant), returns - // true. - bool AverageCodeDuplicationTooHigh(const HloInstruction* producer) const; + // Evaluate the number of times 'producer' would be emitted if it is fused + // into 'fusion_'. If the duplication is "too high" (some arbitrary chosen + // constant), returns true. + bool CodeDuplicationTooHigh(const HloInstruction* producer) const; - // Evaluate the total number of times an instruction is emitted inside the - // fusion node, if 'producer' is fused into 'fusion_'. An instruction may be - // emitted several times, once for each different index value with which it is - // indexed. - int64 EvaluateTotalEmittedInstructions(const HloInstruction* producer) const; + // Evaluate the number of times 'producer' would be emitted if it is fused + // into 'fusion_'. + int64 EvaluateEmittedInstructions(const HloInstruction* producer) const; // Update the evaluation cache after having fused 'producer' into 'fusion_'. 
// 'producer' is the cloned instruction which is now part of the fusion @@ -84,9 +81,6 @@ class FusionNodeIndexingEvaluation { // The fusion instruction. const HloInstruction* fusion_; - - // The total number of emitted instructions. - int64 total_emitted_instructions_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation_test.cc b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation_test.cc index b20f52d2d62..b00abdc9abf 100644 --- a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation_test.cc +++ b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation_test.cc @@ -29,7 +29,7 @@ using FusionNodeIndexingEvaluationTest = HloTestBase; // Subclass of InstructionFusion exposing the protected methods Fuse and // FuseInstruction for testing. Also adds the FusionNodeIndexingEvaluation to -// track the average code duplication due to indexing HloInstructions with +// track the code duplication due to indexing HloInstructions with // different index values. class InstructionFusionForTesting : public InstructionFusion { public: @@ -61,8 +61,8 @@ class InstructionFusionForTesting : public InstructionFusion { return InstructionFusion::Fuse(producer, consumer); } - int64 EvaluateTotalEmittedInstructions(const HloInstruction* producer, - const HloInstruction* consumer) { + int64 EvaluateEmittedInstructions(const HloInstruction* producer, + const HloInstruction* consumer) { if (consumer->opcode() != HloOpcode::kFusion) { return 0; } @@ -71,8 +71,8 @@ class InstructionFusionForTesting : public InstructionFusion { fusion_node_evaluations_.emplace(consumer, FusionNodeIndexingEvaluation(consumer)); } - return fusion_node_evaluations_.at(consumer) - .EvaluateTotalEmittedInstructions(producer); + return fusion_node_evaluations_.at(consumer).EvaluateEmittedInstructions( + producer); } private: @@ -109,8 +109,7 @@ TEST_F(FusionNodeIndexingEvaluationTest, FuseThreeInstructions) { HloInstruction* slice1 = sub->mutable_operand(0); HloInstruction* slice2 = sub->mutable_operand(1); auto fusion = instruction_fusion.Fuse(slice1, sub); - EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(slice2, fusion), - 3); + EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice2, fusion), 1); instruction_fusion.Fuse(slice2, fusion); } @@ -151,37 +150,31 @@ TEST_F(FusionNodeIndexingEvaluationTest, ExponentialDuplicationPattern) { HloInstruction* slice2_1 = add2->mutable_operand(1); auto fusion = instruction_fusion.Fuse(slice2_0, add2); // So far we have fused add2 and slice2.0. So when we also fuse slice2.1, we - // expect to emit 3 instructions. - EXPECT_EQ( - instruction_fusion.EvaluateTotalEmittedInstructions(slice2_1, fusion), 3); + // expect to emit it 1 time. + EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice2_1, fusion), + 1); instruction_fusion.Fuse(slice2_1, fusion); HloInstruction* add1 = fusion->mutable_operand(0); EXPECT_EQ(add1->opcode(), HloOpcode::kAdd); - // If we fuse add1 into 'fusion', it needs to be emitted twice, adding 2 to - // the sum. - EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(add1, fusion), - 5); + // If we fuse add1 into 'fusion', it needs to be emitted twice. + EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(add1, fusion), 2); instruction_fusion.Fuse(add1, fusion); HloInstruction* slice1_0 = fusion->mutable_operand(0); EXPECT_EQ(slice1_0->opcode(), HloOpcode::kSlice); - // If we fuse slice1.0 into 'fusion', it needs to be emitted twice, adding 2 - // to the sum. 
- EXPECT_EQ( - instruction_fusion.EvaluateTotalEmittedInstructions(slice1_0, fusion), 7); + // If we fuse slice1.0 into 'fusion', it needs to be emitted twice. + EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice1_0, fusion), + 2); instruction_fusion.Fuse(slice1_0, fusion); HloInstruction* slice1_1 = fusion->mutable_operand(0); EXPECT_EQ(slice1_1->opcode(), HloOpcode::kSlice); - // If we fuse slice1.1 into 'fusion', it needs to be emitted twice, adding 2 - // to the sum. - EXPECT_EQ( - instruction_fusion.EvaluateTotalEmittedInstructions(slice1_1, fusion), 9); + // If we fuse slice1.1 into 'fusion', it needs to be emitted twice. + EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice1_1, fusion), + 2); instruction_fusion.Fuse(slice1_1, fusion); HloInstruction* add0 = fusion->mutable_operand(0); EXPECT_EQ(add0->opcode(), HloOpcode::kAdd); - // If we fuse add0 into 'fusion', it needs to be emitted twice, adding 4 to - // the sum. - EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(add0, fusion), - 13); + // If we fuse add0 into 'fusion', it needs to be emitted four times. + EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(add0, fusion), 4); instruction_fusion.Fuse(add0, fusion); } @@ -212,10 +205,9 @@ ENTRY entry_computation { HloInstruction* add0 = fusion->mutable_operand(0); EXPECT_EQ(add0->opcode(), HloOpcode::kAdd); // Here, the cache for the fusion node needs to be recomputed. Make sure we - // still get the same evaluation as before when we incrementally built the + // still get the same evaluation as before when we incrementally build the // cache. - EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(add0, fusion), - 13); + EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(add0, fusion), 4); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 1838f65e6ea..d38873a501d 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -269,6 +269,22 @@ static StatusOr PermuteBatchAndOffsetDims( return MakeTransposeHlo(accumulator, permutation); } +// Computes how many trips a loop implementing this gather op would take. 
+static int64 GatherLoopTripCount(HloInstruction* gather_instr) { + HloInstruction* start_indices = gather_instr->mutable_operand(1); + const Shape& start_indices_shape = start_indices->shape(); + const GatherDimensionNumbers& dim_numbers = + gather_instr->gather_dimension_numbers(); + + int64 trip_count = 1; + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + trip_count *= start_indices_shape.dimensions(i); + } + } + return trip_count; +} + // High Level Algorithm // // We follow the following steps in sequence: @@ -311,20 +327,13 @@ StatusOr GatherExpander::ExpandInstruction( HloComputation* computation = gather_instr->parent(); HloInstruction* operand = gather_instr->mutable_operand(0); HloInstruction* start_indices = gather_instr->mutable_operand(1); - const Shape& start_indices_shape = start_indices->shape(); const Shape& output_shape = gather_instr->shape(); int64 output_rank = output_shape.dimensions_size(); const GatherDimensionNumbers& dim_numbers = gather_instr->gather_dimension_numbers(); - int64 gather_loop_trip_count = 1; - for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { - if (i != dim_numbers.index_vector_dim()) { - gather_loop_trip_count *= start_indices_shape.dimensions(i); - } - } - + int64 gather_loop_trip_count = GatherLoopTripCount(gather_instr); if (!IsInt32(gather_loop_trip_count)) { return Unimplemented( "Gather operations with more than 2147483647 gather indices are not " @@ -373,7 +382,11 @@ bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) { return inst->opcode() == HloOpcode::kGather && // Avoid expanding gather ops that produce zero sized tensors, // instead punt these to ZeroSizedHloElimination. - !ShapeUtil::IsZeroElementArray(inst->shape()); + !ShapeUtil::IsZeroElementArray(inst->shape()) && + // In kEliminateSimpleGathers mode, we only simplify instructions + // which can be represented without a loop -- i.e. we only simplify + // gathers which have a trip count of 1. + (mode_ == kEliminateAllGathers || GatherLoopTripCount(inst) == 1); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index 5625a37cb46..e665fcd713c 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -21,10 +21,30 @@ limitations under the License. namespace xla { // This pass rewrites gather operations into (roughly) while loops of dynamic -// slices. This lets backends that don't support gather directly to -// nevertheless have a minimum level of support. +// slices. +// +// This pass can be used two ways: +// +// - kEliminateAllGathers: For backends that don't support gather, this pass +// can convert every gather to a loop. +// +// - kEliminateSimpleGathers: For backends that *do* support gather, this pass +// can strength-reduce "simple" gathers -- specifically, gathers that can be +// represented without a loop -- to dyanmic-slices. +// +// Note that even in kEliminateSimpleGathers mode, this pass may still expand a +// gather into a loop (with a trip-count of 1). It's up to other simplification +// passes to remove the loop. 
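Two quick trip-count checks, taken from the tests added to gather_expander_test.cc further below, show how the new mode interacts with GatherLoopTripCount:

// - start_indices of shape s32[2] with index_vector_dim=1: dimension 0 is the
//   only dimension different from index_vector_dim, so the trip count is 2;
//   in kEliminateSimpleGathers mode the gather is left alone.
// - start_indices of shape s32[1] with index_vector_dim=0: every dimension is
//   the index vector dimension, so the trip count is 1 and the gather is
//   expanded in either mode (the resulting trip-count-1 loop is left to later
//   simplification passes).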
+// class GatherExpander : public OpExpanderPass { public: + enum Mode { + kEliminateAllGathers, + kEliminateSimpleGathers, + }; + + explicit GatherExpander(Mode m) : mode_(m) {} + absl::string_view name() const override { return "gather_expander"; } protected: @@ -32,6 +52,9 @@ class GatherExpander : public OpExpanderPass { StatusOr ExpandInstruction( HloInstruction* gather_inst) override; + + private: + Mode mode_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index 706327091d9..4b0808e9aaf 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -42,7 +43,9 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - Status status = GatherExpander{}.Run(module.get()).status(); + Status status = GatherExpander{GatherExpander::kEliminateAllGathers} + .Run(module.get()) + .status(); EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); ASSERT_THAT( @@ -68,7 +71,9 @@ ENTRY main { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get())); ASSERT_TRUE(changed); HloInstruction* while_instr = nullptr; @@ -129,7 +134,9 @@ ENTRY main { OpMetadata metadata; metadata.set_op_name("Gather"); module->entry_computation()->root_instruction()->set_metadata(metadata); - TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get())); ASSERT_TRUE(changed); HloInstruction* while_instr = nullptr; @@ -147,5 +154,54 @@ ENTRY main { "after gather expansion"; EXPECT_EQ(while_instr->metadata().op_name(), "Gather"); } + +TEST_F(GatherExpanderTest, EliminateSimpleGathersSkipsNontrivialGather) { + const string hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,3] gather(operand, indices), + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, + index_vector_dim=1, + slice_sizes={1, 3} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GatherExpander pass(GatherExpander::kEliminateSimpleGathers); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get())); + ASSERT_FALSE(changed); +} + +TEST_F(GatherExpanderTest, EliminateSimpleGathersRewritesTrivialGather) { + const string hlo_text = R"( +HloModule test + +ENTRY main { + operand = s32[100] parameter(0) + indices = s32[1] parameter(1) + ROOT gather = s32[10] gather(operand, indices), + offset_dims={0}, + collapsed_slice_dims={}, + start_index_map={0}, + index_vector_dim=0, + slice_sizes={10} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GatherExpander pass(GatherExpander::kEliminateAllGathers); + TF_ASSERT_OK_AND_ASSIGN(bool changed, 
RunHloPass(&pass, module.get())); + ASSERT_TRUE(changed); + ASSERT_FALSE(hlo_query::ContainsInstrWithOpcode(module->entry_computation(), + {HloOpcode::kGather})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 8dfd73e9a6a..c861ceffc05 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -254,11 +254,18 @@ cc_library( ":target_util", ":thunk", ":thunk_emitter", + "//tensorflow/compiler/mlir/hlo", + "//tensorflow/compiler/mlir/hlo:lhlo", + "//tensorflow/compiler/mlir/xla:hlo_utils", + "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla", + "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", + "//tensorflow/compiler/mlir/xla:type_to_shape", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -286,10 +293,13 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", ], ) @@ -1025,6 +1035,24 @@ cc_library( ], ) +tf_cc_test( + name = "gpu_conv_padding_legalization_test", + srcs = ["gpu_conv_padding_legalization_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_conv_padding_legalization", + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:test", + ], +) + cc_library( name = "cudnn_pad_for_convolutions", srcs = ["cudnn_pad_for_convolutions.cc"], @@ -1144,6 +1172,7 @@ cc_library( ":gpu_sanitize_constant_names", ":gpu_scatter_expander", ":horizontal_fusion", + ":horizontal_input_fusion", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", @@ -1158,6 +1187,7 @@ cc_library( ":target_constants", ":tree_reduction_rewriter", ":variadic_op_splitter", + "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1168,6 +1198,7 @@ cc_library( "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:comparison_expander", "//tensorflow/compiler/xla/service:conditional_canonicalizer", "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_4d_expander", @@ -1177,13 +1208,13 @@ cc_library( "//tensorflow/compiler/xla/service:dynamic_padder", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", + "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", 
"//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_dce", "//tensorflow/compiler/xla/service:hlo_element_type_converter", - "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto_util", @@ -1214,6 +1245,8 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@llvm-project//llvm:Core", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:IR", ], ) @@ -1480,6 +1513,7 @@ cc_library( hdrs = ["stream_executor_util.h"], copts = tf_copts(), deps = [ + ":launch_dimensions", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1718,6 +1752,7 @@ cc_library( srcs = ["horizontal_fusion.cc"], hdrs = ["horizontal_fusion.h"], deps = [ + ":gpu_fusible", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_creation_utils", @@ -1754,6 +1789,45 @@ tf_cc_test( ], ) +cc_library( + name = "horizontal_input_fusion", + srcs = ["horizontal_input_fusion.cc"], + hdrs = ["horizontal_input_fusion.h"], + deps = [ + ":gpu_fusible", + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_creation_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "horizontal_input_fusion_test", + srcs = ["horizontal_input_fusion_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":horizontal_input_fusion", + ":multi_output_fusion", + "//tensorflow/compiler/jit:xla_gpu_jit", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "reduction_degenerate_dim_remover", srcs = ["reduction_degenerate_dim_remover.cc"], @@ -1909,16 +1983,3 @@ cc_library( "@llvm-project//mlir:LLVMDialect", ], ) - -# Library with XLA thunks dialect static initialization. 
-cc_library( - name = "xla_thunks_dialect_registration", - srcs = [ - "ir/dialect_registration.cc", - ], - deps = [ - ":xla_thunks_ops", - "@llvm-project//mlir:IR", - ], - alwayslink = 1, -) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 9b192aaa8e1..10a565308de 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -613,10 +613,13 @@ static StatusOr DeviceCompare(se::Stream* stream, LaunchDimensions dim = CalculateLaunchDimensions(buffer_shape, gpu_device_info); - stream->ThenLaunch(se::ThreadDim(dim.threads_per_block()), - se::BlockDim(dim.block_count()), *comparison_kernel, - lhs_typed, rhs_typed, static_cast(kTolerance), - buffer_size, out_param.cref()); + LaunchDimensions::Dim3D thread_counts = dim.thread_counts_per_block(); + LaunchDimensions::Dim3D block_counts = dim.block_counts(); + stream->ThenLaunch( + se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z), + se::BlockDim(block_counts.x, block_counts.y, block_counts.z), + *comparison_kernel, lhs_typed, rhs_typed, static_cast(kTolerance), + buffer_size, out_param.cref()); uint64 result = -1; CHECK_EQ(out_param->size(), sizeof(result)); diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 60e4cb84b09..fa066e9d320 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -201,8 +201,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Merging into all users enables the removal of 'fusion' from the // computation. if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) { - return user->opcode() == HloOpcode::kFusion && - IsProducerConsumerFusible(*fusion, *user); + return IsProducerConsumerFusible(*fusion, *user); })) { VLOG(3) << "Not merging " << fusion->name() << ": Some of its users are not loop/input fusion kernels."; @@ -230,18 +229,15 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // This is done to avoid the duplication of expensive instructions, which // would occur if 'fusion' were merged into multiple users. // - // If 'fusion' has just one user, then an earlier fusion pass chose not to - // fuse this producer/consumer pair (likely because of expensive instruction - // re-use by the consumer), and so we honor that choice here as well. - // - // Moreover, if we are going to save a "lot" in memory bandwidth then we + // However, if we are going to save a "lot" in memory bandwidth then we // ignore how expensive the fusion instructions are. The heuristic used to // determine "a lot" is the following: merging must reduce memory traffic by a // factor of 0.3, and the amount of memory accessed must not be entirely // trivial (above 1K). This likely has room for improvement in the future. bool allow_expensive_ops = - merged_to_current_bytes_ratio < 0.3 && current_bytes_transferred > 1024; + fusion->user_count() == 1 || + (merged_to_current_bytes_ratio < 0.3 && current_bytes_transferred > 1024); if (!allow_expensive_ops && absl::c_any_of(fusion->fused_instructions(), @@ -286,7 +282,15 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Merge fused instructions from 'fusion' into each user. 
std::vector users = fusion->users(); for (HloInstruction* user : users) { - user->MergeFusionInstruction(fusion); + if (user->opcode() == HloOpcode::kFusion) { + user->MergeFusionInstruction(fusion); + } else { + HloInstruction* fused_user = + computation_->AddInstruction(HloInstruction::CreateFusion( + user->shape(), ChooseFusionKind(*fusion, *user), user)); + TF_CHECK_OK(computation_->ReplaceInstruction(user, fused_user)); + fused_user->MergeFusionInstruction(fusion); + } changed_ = true; } ++total_merged_; @@ -299,7 +303,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { }) << " }"; // Remove 'fusion' instruction. - CHECK_EQ(0, fusion->user_count()); + CHECK_EQ(0, fusion->user_count()) << fusion->ToString(); return computation_->RemoveInstruction(fusion); } diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index 42891154c23..d08c732e611 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -234,6 +234,54 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { op::Fusion(op::Parameter())); } +TEST_F(FusionMergerTest, WillMergeIntoUnfusedConsumer) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule jit_matmul.36 + + max (parameter.13: f32[], parameter.14: f32[]) -> f32[] { + parameter.13 = f32[] parameter(0) + parameter.14 = f32[] parameter(1) + ROOT maximum.15 = f32[] maximum(f32[] parameter.13, f32[] parameter.14) + } + + add (parameter.29: f32[], parameter.30: f32[]) -> f32[] { + parameter.29 = f32[] parameter(0) + parameter.30 = f32[] parameter(1) + ROOT add.31 = f32[] add(f32[] parameter.29, f32[] parameter.30) + } + + fused_computation.1 (param_1.4: f32[200,200,200], param_2.1: f32[200,200]) -> f32[200,200] { + param_1.4 = f32[200,200,200]{2,1,0} parameter(0) + param_2.1 = f32[200,200]{1,0} parameter(1) + broadcast.3 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_2.1), dimensions={0,2} + subtract.0 = f32[200,200,200]{2,1,0} subtract(f32[200,200,200]{2,1,0} param_1.4, f32[200,200,200]{2,1,0} broadcast.3) + exponential.0 = f32[200,200,200]{2,1,0} exponential(f32[200,200,200]{2,1,0} subtract.0) + constant.27 = f32[] constant(0) + ROOT reduce.0 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} exponential.0, f32[] constant.27), dimensions={1}, to_apply=add + } + + fused_computation.3 (param_0.7: f32[200,200], param_1.9: f32[200,200]) -> f32[200,200,200] { + param_1.9 = f32[200,200]{1,0} parameter(1) + broadcast.10 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_1.9), dimensions={0,1} + param_0.7 = f32[200,200]{1,0} parameter(0) + broadcast.8 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_0.7), dimensions={1,2} + ROOT add.1 = f32[200,200,200]{2,1,0} add(f32[200,200,200]{2,1,0} broadcast.10, f32[200,200,200]{2,1,0} broadcast.8) + } + + ENTRY entry (parameter.1: f32[200,200], parameter.2: f32[200,200]) -> f32[200,200] { + parameter.2 = f32[200,200]{1,0} parameter(1) + parameter.1 = f32[200,200]{1,0} parameter(0) + fusion.3 = f32[200,200,200]{2,1,0} fusion(f32[200,200]{1,0} parameter.2, f32[200,200]{1,0} parameter.1), kind=kLoop, calls=fused_computation.3 + constant.11 = f32[] constant(-inf) + reduce.16 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} fusion.3, f32[] constant.11), dimensions={1}, to_apply=max + ROOT fusion.1 = f32[200,200]{1,0} fusion(f32[200,200,200]{2,1,0} fusion.3, f32[200,200]{1,0} reduce.16), kind=kInput, calls=fused_computation.1 
+ })") + .ValueOrDie(); + EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Fusion(op::Fusion(), op::Parameter(), op::Parameter())); +} + TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) { auto module = ParseAndReturnVerifiedModule(R"( HloModule m @@ -398,6 +446,29 @@ TEST_F(FusionMergerTest, WillMergeExpensiveFusionsIfSavesMemory) { EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); } +TEST_F(FusionMergerTest, WillMergeExpensiveFusionsWithSingleConsumer) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule m + + %f_b (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] { + %p = f32[1024,1024,1024] parameter(0) + ROOT %t = f32[1024,1024,1024] tanh(%p) + } + + %f_c (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] { + %p = f32[1024,1024,1024] parameter(0) + ROOT %t = f32[1024,1024,1024] add(%p, %p) + } + + ENTRY entry { + p0 = f32[1024,1024,1024] parameter(0) + f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_b + ROOT f2 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_c + })") + .ValueOrDie(); + EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index f2d29b5d11f..cc4de2c1099 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -29,12 +29,15 @@ limitations under the License. #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/all_reduce_combiner.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/comparison_expander.h" #include "tensorflow/compiler/xla/service/conditional_canonicalizer.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/convolution_4d_expander.h" @@ -43,6 +46,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/dynamic_padder.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" @@ -56,6 +60,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" #include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h" #include "tensorflow/compiler/xla/service/gpu/horizontal_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" @@ -79,7 +84,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" -#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" @@ -139,6 +143,9 @@ Status GpuCompiler::OptimizeHloModule( pipeline.AddPass(); pipeline.AddPass(RandomAlgorithm::RNG_PHILOX); + // Comparison total order expander + pipeline.AddPass(); + // Remove zero-sized HLO from the input so that other passes don't have to // handle it. pipeline.AddPass(); @@ -190,11 +197,12 @@ Status GpuCompiler::OptimizeHloModule( /*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - pipeline.AddPass(); - // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. - pipeline.AddPass(); + pass.AddPass(); + + pass.AddPass(GatherExpander::kEliminateSimpleGathers); + pass.AddPass(ScatterExpander::kEliminateSimpleScatters); AlgebraicSimplifierOptions options; // When transposes appear in a fusion node, we can easily adjust the @@ -295,11 +303,13 @@ Status GpuCompiler::OptimizeHloModule( HloPassPipeline horizontal_fusion("horizontal_fusion"); horizontal_fusion.AddPass(); + horizontal_fusion.AddPass(); horizontal_fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); horizontal_fusion.AddPass(); TF_RETURN_IF_ERROR(horizontal_fusion.Run(hlo_module).status()); } + { HloPassPipeline pipeline("all_reduce_combiner"); pipeline.AddPass( @@ -476,7 +486,8 @@ static Status CompileModuleToLlvmIrImpl( int pointer_size, const HloProfileIndexMap* profile_index_map, std::unique_ptr* llvm_module, std::unique_ptr* buffer_assignment, - std::unique_ptr* thunk_schedule) { + std::unique_ptr* thunk_schedule, + std::vector* constants) { *llvm_module = absl::make_unique("", *llvm_context); (*llvm_module)->setTargetTriple(target_triple); @@ -509,15 +520,19 @@ static Status CompileModuleToLlvmIrImpl( DumpHloModuleIfEnabled(*hlo_module, **buffer_assignment, "after_optimizations"); + mlir::MLIRContext mlir_context; + IrEmitterContext ir_emitter_context( hlo_module, buffer_assignment->get(), platform_name, gpu_device_info, - cuda_compute_capability, profile_index_map, llvm_module->get()); + cuda_compute_capability, profile_index_map, &mlir_context, + llvm_module->get()); HloComputation* entry_computation = hlo_module->entry_computation(); - IrEmitterUnnested ir_emitter(hlo_module->config(), entry_computation, - &ir_emitter_context); - TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + TF_ASSIGN_OR_RETURN( + auto ir_emitter, + IrEmitterUnnested::Create(hlo_module->config(), entry_computation, + &ir_emitter_context)); { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission"); @@ -526,9 +541,10 @@ static Status CompileModuleToLlvmIrImpl( ThunkSequence thunk_sequence; absl::Span order = hlo_schedule->ThunkLaunchOrder(); for (HloInstruction* instruction : order) { - TF_RETURN_IF_ERROR(instruction->Visit(&ir_emitter)); - TF_RETURN_IF_ERROR(ir_emitter.Postprocess(instruction)); - std::unique_ptr thunks = ir_emitter.ConsumeThunkSequence(); + TF_RETURN_IF_ERROR(instruction->Visit(ir_emitter.get())); + TF_RETURN_IF_ERROR(ir_emitter->Postprocess(instruction)); + std::unique_ptr thunks = + ir_emitter->ConsumeThunkSequence(); // The invariants between each input 
HloInstruction* and output Thunk* are // not all explicitly checked, but at least we can document them here: @@ -566,6 +582,10 @@ static Status CompileModuleToLlvmIrImpl( *thunk_schedule = absl::make_unique( std::make_unique(std::move(thunk_sequence)), std::move(stream_assignment), std::move(thunk_to_hlo)); + + if (constants) { + *constants = std::move(ir_emitter_context.constants()); + } } return Status::OK(); @@ -631,12 +651,13 @@ StatusOr> GpuCompiler::RunBackend( std::unique_ptr llvm_module; std::unique_ptr buffer_assignment; std::unique_ptr thunk_schedule; + std::vector constants; TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl( module.get(), &llvm_context, target_triple_, data_layout_, stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability, GetCanShareBuffer(), pointer_size_, profile_index_map.get(), &llvm_module, - &buffer_assignment, &thunk_schedule)); + &buffer_assignment, &thunk_schedule, &constants)); if (user_pre_optimization_hook_) { user_pre_optimization_hook_(*llvm_module); @@ -682,7 +703,7 @@ StatusOr> GpuCompiler::RunBackend( backend_result.first, backend_result.second, gpu_version, std::move(thunk_schedule), std::move(module), std::move(buffer_assignment), std::move(profile_printer), - std::move(profile_index_map)); + std::move(profile_index_map), std::move(constants)); if (embed_ir_in_executable) { DCHECK_NE("", ir_module_string_before_opt); gpu_executable->set_ir_module_string(ir_module_string_before_opt); @@ -716,7 +737,7 @@ StatusOr> CompileModuleToLlvmIr( hlo_module, llvm_context, target_triple, data_layout, platform_name, gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction, pointer_size, /*profile_index_map=*/nullptr, &llvm_module, - &buffer_assignment, &thunk_schedule)); + &buffer_assignment, &thunk_schedule, nullptr)); return llvm_module; } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.cc index 5fa102ac785..94f9a96c0fe 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.cc @@ -313,7 +313,11 @@ bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution( new_backward_conv_window.mutable_dimensions(i)); } // Decreasing the padding by X *increases* the size of our output by X. - int64 dim = backward_conv_dnums.output_spatial_dimensions(i); + // Note that we have swapped input spatial dimensions with output spatial + // dimensions to be compatible with the cuDNN API, so + // input_spatial_dimensions(i) gives the i-th spatial dimension of the + // output. + int64 dim = backward_conv_dnums.input_spatial_dimensions(i); new_backward_conv_shape.set_dimensions( dim, new_backward_conv_shape.dimensions(dim) + std::abs(padding_low - padding_high)); @@ -353,7 +357,11 @@ bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution( for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { int64 padding_low = backward_conv->window().dimensions(i).padding_low(); int64 padding_high = backward_conv->window().dimensions(i).padding_high(); - int64 dim = backward_conv_dnums.output_spatial_dimensions(i); + // Note that we have swapped input spatial dimensions with output spatial + // dimensions to be compatible with the cuDNN API, so + // input_spatial_dimensions(i) gives the i-th spatial dimension of the + // output. 
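// Illustrative walk-through, derived from the BackwardInputConvolve test added
// below (not wording from the original patch): there the window has
// pad=0_0x0_1, so padding_low (0) and padding_high (1) differ by one in the
// second spatial dimension; the canonicalized backward convolution therefore
// produces f64[2,2,4,5], and a trailing slice restores the original
// f64[2,2,4,4] result shape.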
+ int64 dim = backward_conv_dnums.input_spatial_dimensions(i); if (padding_low > padding_high) { // If the amount of low padding (of the old backward convolution) is // larger, we internally pad the low end of the activations and slice diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization_test.cc new file mode 100644 index 00000000000..c214486e18f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization_test.cc @@ -0,0 +1,97 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h" + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; + +using GpuConvPaddingLegalizationTest = HloTestBase; + +TEST_F(GpuConvPaddingLegalizationTest, BackwardInputConvolve) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule convolution_module +ENTRY %convolution (operand f64[2,2,2,3]{3,2,1,0}) -> (f64[2,2,4,4]{3,2,1,0}, u8[0]) { + %operand = f64[2,2,2,3]{3,2,1,0} parameter(0) + %kernel = f64[2,3,2,3]{3,2,1,0} constant( + { + { /*i0=0*/ + { /*i1=0*/ + { 0.29629629629629628, 0.30246913580246915, 0.30864197530864196 }, + { 0.31481481481481483, 0.32098765432098764, 0.3271604938271605 } + }, + { /*i1=1*/ + { 0.25925925925925924, 0.26543209876543211, 0.27160493827160492 }, + { 0.27777777777777779, 0.2839506172839506, 0.29012345679012347 } + }, + { /*i1=2*/ + { 0.22222222222222221, 0.22839506172839505, 0.23456790123456789 }, + { 0.24074074074074073, 0.24691358024691357, 0.25308641975308643 } + } + }, + { /*i0=1*/ + { /*i1=0*/ + { 0.18518518518518517, 0.19135802469135801, 0.19753086419753085 }, + { 0.20370370370370369, 0.20987654320987653, 0.21604938271604937 } + }, + { /*i1=1*/ + { 0.14814814814814814, 0.15432098765432098, 0.16049382716049382 }, + { 0.16666666666666666, 0.1728395061728395, 0.17901234567901234 } + }, + { /*i2=2*/ + { 0.1111111111111111, 0.11728395061728394, 0.12345679012345678 }, + { 0.12962962962962962, 0.13580246913580246, 0.1419753086419753 } + } + } + }) + %reverse = f64[2,3,2,3]{3,2,1,0} reverse(%kernel), dimensions={0,1} + ROOT %custom-call = (f64[2,2,4,4]{3,2,1,0}, u8[0]{0}) custom-call(f64[2,2,2,3]{3,2,1,0} %operand, f64[2,3,2,3]{3,2,1,0} %reverse), 
window={size=2x3 stride=2x2 pad=0_0x0_1}, dim_labels=bf01_01io->b01f, custom_call_target="__cudnn$convBackwardInput", backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" +} + )") + .ValueOrDie(); + ASSERT_TRUE(GpuConvPaddingLegalization().Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Tuple(op::Slice(op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget, _, + op::Reverse(op::Constant())), + 0)), + op::GetTupleElement())); + auto slice = root->operand(0); + Shape expected_slice_shape = ShapeUtil::MakeShape(F64, {2, 2, 4, 4}); + EXPECT_TRUE(ShapeUtil::Equal(slice->shape(), expected_slice_shape)); + auto conv = slice->operand(0); + Shape expected_conv_shape = ShapeUtil::MakeShape(F64, {2, 2, 4, 5}); + EXPECT_TRUE(ShapeUtil::Equal(conv->shape(), expected_conv_shape)); +} + +} // anonymous namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 469f2919fba..c963dfb2b2a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -60,14 +60,16 @@ GpuExecutable::GpuExecutable( std::shared_ptr hlo_module, std::shared_ptr assignment, std::unique_ptr hlo_profile_printer_data, - std::unique_ptr hlo_profile_index_map) + std::unique_ptr hlo_profile_index_map, + std::vector globals) : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)), text_(text), binary_(binary), gpu_version_(gpu_version), thunk_schedule_(std::move(thunk_schedule)), - assignment_(std::move(assignment)) { + assignment_(std::move(assignment)), + constants_(std::move(globals)) { CHECK(has_module() && assignment_); GpuDebugInfoManager::Get()->RegisterModule(module().name(), shared_module(), assignment_); @@ -280,28 +282,23 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) { se::ModuleHandle module_handle; TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &module_handle)); - for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); - ++i) { - const BufferAllocation& allocation = assignment_->GetAllocation(i); - if (allocation.is_constant()) { - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase global, - executor->GetUntypedSymbol( - llvm_ir::ConstantBufferAllocationToGlobalName(allocation), - module_handle)); - VLOG(3) << "Resolved global " - << llvm_ir::ConstantBufferAllocationToGlobalName(allocation) - << " to " << global.opaque(); - InsertOrDie(&globals, i, global); + for (const auto& info : constants_) { + const Literal& literal = info.content; - const Literal& literal = - llvm_ir::LiteralForConstantAllocation(allocation); - CHECK(literal.shape().IsArray()); - if (!ShouldEmitLiteralInLlvmIr(literal)) { - VLOG(3) << "H2D memcpy for constant with shape " - << ShapeUtil::HumanString(literal.shape()); - stream->ThenMemcpy(&global, literal.untyped_data(), allocation.size()); - } + TF_ASSIGN_OR_RETURN(auto global, executor->GetUntypedSymbol( + info.symbol_name, module_handle)); + VLOG(3) << "Resolved global " << info.symbol_name << " to " + << global.opaque(); + + CHECK(literal.shape().IsArray()); + if (!ShouldEmitLiteralInLlvmIr(literal)) { + VLOG(3) << "H2D memcpy for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + stream->ThenMemcpy(&global, literal.untyped_data(), 
literal.size_bytes()); + } + + if (info.allocation_index != -1) { + InsertOrDie(&globals, info.allocation_index, global); } } @@ -334,7 +331,11 @@ StatusOr GpuExecutable::BufferForAllocation( } return registered_buffer; } else if (allocation.is_constant()) { - return FindOrDie(*globals, arg_idx); + auto it = globals->find(arg_idx); + if (it == globals->end()) { + return se::DeviceMemoryBase(); + } + return it->second; } else { // Allocate each allocation that might escape, or is the temp buffer. CHECK(allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()); @@ -480,6 +481,12 @@ StatusOr GpuExecutable::ExecuteAsyncOnStream( ExecutionInput& input = arguments[alias->parameter_number]; MaybeOwningDeviceMemory* maybe_owning_memory = input.MutableBuffer(alias->parameter_index); + if (alias->must_alias() && !maybe_owning_memory->HasOwnership()) { + return InvalidArgument( + "An input was configured to be must-alias at " + "compile time but not donated at runtime: %s", + alias->ToString()); + } if (absl::optional owning = maybe_owning_memory->Release()) { // If the caller passes the ownership of the device memory, reuse it diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 516fa9b269a..613880fd44b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -49,6 +49,12 @@ namespace gpu { // This is an immutable data type after initialization, and thus thread safe. class GpuExecutable : public Executable { public: + struct ConstantInfo { + std::string symbol_name; + xla::Literal content; + int allocation_index = -1; + }; + // We need to share ownership of hlo_module and assignment with profiler to // safely keep a reference to these objects during tracing period, thus they // are passed as shared pointers. @@ -58,7 +64,8 @@ class GpuExecutable : public Executable { std::shared_ptr hlo_module, std::shared_ptr assignment, std::unique_ptr hlo_profile_printer_data, - std::unique_ptr hlo_profile_index_map); + std::unique_ptr hlo_profile_index_map, + std::vector constants); ~GpuExecutable() override; int64 SizeOfGeneratedCodeInBytes() const override; @@ -169,6 +176,8 @@ class GpuExecutable : public Executable { std::map module_globals_ TF_GUARDED_BY(module_handle_mutex_); + std::vector constants_; + TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index bb4184ff76f..e56fe4dd74b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -143,29 +143,27 @@ bool IsInputFusibleReduction(const HloInstruction& instr) { IsReductionFromOrToContiguousDimensions(instr); } +const HloInstruction* GetRealHeroForMultiOutputFusion( + const HloInstruction& instr) { + if (instr.opcode() != HloOpcode::kFusion) { + return &instr; + } + auto fused_expression_root = instr.fused_expression_root(); + if (!instr.IsMultiOutputFusion()) { + return fused_expression_root; + } + // If possible, we want to pick a reduction-from-or-to-contiguous-dims + // operand of the fusion root, because it has the most constraints. 
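// Hypothetical example (not taken from this change): for a multi-output
// fusion whose root is tuple(reduce(exp(p0), c), exp(p0)), the reduce operand
// is returned as the hero, since the reduction emitter imposes the strictest
// shape and layout constraints; if no reduction operand exists, we fall back
// to the first root operand after the loop below.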
+ for (const auto* inst : fused_expression_root->operands()) { + if (IsReductionFromOrToContiguousDimensions(*inst)) { + return inst; + } + } + return fused_expression_root->operands()[0]; +} + bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, const HloInstruction& instr2) { - // Returns the instructions that determines the emitter used for lowering, - // sometimes referred to as "the real hero". - auto get_real_hero = - [&](const HloInstruction* instr) -> const HloInstruction* { - if (instr->opcode() != HloOpcode::kFusion) { - return instr; - } - auto fused_expression_root = instr->fused_expression_root(); - if (!instr->IsMultiOutputFusion()) { - return fused_expression_root; - } - // If possible, we want to pick a reduction-to-vector operand of the - // fusion root, because it has the most constraints. - for (const auto* inst : fused_expression_root->operands()) { - if (IsReductionFromOrToContiguousDimensions(*inst)) { - return inst; - } - } - return fused_expression_root->operands()[0]; - }; - // Multi-output fusion kernels share a common parallel loop. The loop // dimensions are determined by instruction shapes. auto get_loop_shape = [&](const HloInstruction* element_instr) { @@ -181,8 +179,8 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, // root ops should have equal output shapes. An exception are // reduction-to-vector ops. Here the input shapes of the reduction (first // operand shape) and the reduction dimensions need to match. - auto* instr_1 = get_real_hero(&instr1); - auto* instr_2 = get_real_hero(&instr2); + auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1); + auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2); if (IsReductionFromOrToContiguousDimensions(*instr_1) && IsReductionFromOrToContiguousDimensions(*instr_2) && !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) { @@ -347,8 +345,13 @@ static int64 SharedMemoryUsage(const HloInstruction& instr) { // This limit is also often good for performance. In a fusion with many // operands, each GPU thread likely has to do a lot of work, and so possibly // uses a lot of registers, thus limiting occupancy. +// +// If the fusion is a producer/consumer fusion and instr1 is the +// consumer and instr2 is the producer, set is_consumer_producer_fusion +// to true to enable more fusion. bool FusionWouldBeTooLarge(const HloInstruction& instr1, - const HloInstruction& instr2) { + const HloInstruction& instr2, + bool is_consumer_producer_fusion) { if (SharedMemoryUsage(instr1) + SharedMemoryUsage(instr2) > kSharedMemoryBudgetInBytes) { VLOG(5) << "Shared memory usage of fusion of " << instr1.ToString() @@ -404,6 +407,17 @@ bool FusionWouldBeTooLarge(const HloInstruction& instr1, // producer -> consumer relationship. operands.erase(&instr1); operands.erase(&instr2); + + // If the fused instruction would have no more inputs and outputs than + // before, it won't be bigger after fusion, so accept the fusion. Since + // this is a producer/consumer fusion, the number of consumer outputs does + // not change, so only the operand count needs to be checked here. + if (is_consumer_producer_fusion && + operands.size() <= instr1.operands().size()) { + return false; + } + + // Does the new fusion have more operands and outputs than the max?
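// Hedged illustration (hypothetical instructions, not part of this change):
// for a producer p = exp(a) feeding a consumer c = add(p, a), the fused node
// reads only {a}, so operands.size() (1) <= instr1.operands().size() (2) and
// the producer/consumer early-accept above already returned false; otherwise
// we fall through to the operand/output budget check below.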
return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion; } @@ -490,5 +504,24 @@ HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& /*producer*/, : HloInstruction::FusionKind::kLoop; } +bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, + const HloInstruction& consumer) { + return absl::c_all_of(instr.users(), [&](const HloInstruction* user) { + if (user->opcode() == HloOpcode::kGetTupleElement) { + // Skip GTE. + return IsConsumerTheOnlyNonRootUser(*user, consumer); + } + if (user == &consumer) { + // `user` is `consumer`. + return true; + } + if (user == user->parent()->root_instruction()) { + // Consumed by ROOT. + return true; + } + return false; + }); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index e2a42ecb0a3..5296b8b4096 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -64,14 +64,23 @@ bool IsInputFusibleScatter(const HloInstruction& instr); // Determines whether the combination of `instr1` and `instr2` into a (possibly // multi-output) fusion would be "too large" -- i.e., have more operands and // outputs than is allowed or occupy too much shared memory. +// If the fusion is a producer/consumer fusion and instr1 is the +// consumer and instr2 is the producer, set consumer_producer_fusion +// to true to enable more fusion. bool FusionWouldBeTooLarge(const HloInstruction& instr1, - const HloInstruction& instr2); + const HloInstruction& instr2, + bool is_consumer_producer_fusion = false); // Check if fusing producer and consumer will generate a nested loop, e.g. both // producer and consumer are `reduce-window` HLO instructions. bool CreatesNestedLoop(const HloInstruction& producer, const HloInstruction& consumer); +// Returns the instruction that determines the emitter used for lowering, +// sometimes referred to as "the real hero". +const HloInstruction* GetRealHeroForMultiOutputFusion( + const HloInstruction& instr); + // Whether instruction shapes are compatible for multi-output fusion, i.e. // whether the emitters support lowering the resulting fusion. // This function works for both, sibling and producer-consumer multi-output @@ -101,6 +110,10 @@ bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr); HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& producer, const HloInstruction& consumer); +// Returns whether `consumer` is the only non-root user of `instr`. +bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, + const HloInstruction& consumer); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.cc b/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.cc index 6287f1e3ca2..31f011fa734 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.cc @@ -23,26 +23,11 @@ limitations under the License. namespace xla { -StatusOr GpuScatterExpander::Run(HloModule* module) { - auto is_nontrivial_scatter = [](HloInstruction* inst) { - // TODO(b/129698548): Scattering elements larger than 64 bits is not - // supported by XLA:GPU. 
- return inst->opcode() == HloOpcode::kScatter && - inst->shape().element_type() == C128; - }; - - std::vector scatter_instrs; - for (HloComputation* computation : module->MakeNonfusionComputations()) { - absl::c_copy_if(computation->instructions(), - std::back_inserter(scatter_instrs), is_nontrivial_scatter); - } - - for (HloInstruction* inst : scatter_instrs) { - TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(inst)); - TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); - } - - return !scatter_instrs.empty(); +bool GpuScatterExpander::InstructionMatchesPattern(HloInstruction* inst) { + // TODO(b/129698548): Scattering elements larger than 64 bits is not + // supported by XLA:GPU. + return inst->opcode() == HloOpcode::kScatter && + primitive_util::BitWidth(inst->shape().element_type()) > 64; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h b/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h index 0818b32474f..92acb909729 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h @@ -20,10 +20,17 @@ limitations under the License. namespace xla { +// Legalizes scatters on the GPU. class GpuScatterExpander : public ScatterExpander { public: + // Although we pass kEliminateAllScatters, we override this behavior in + // InstructionMatchesPattern and select only some scatters to expand. + GpuScatterExpander() : ScatterExpander(kEliminateAllScatters) {} + absl::string_view name() const override { return "gpu_scatter_expander"; } - StatusOr Run(HloModule* module) override; + + protected: + bool InstructionMatchesPattern(HloInstruction* inst) override; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 5d38d1b727c..26a22005dae 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -83,6 +83,8 @@ void HloToIrBindings::EmitBasePointersForHlos( if (non_io_hlo->opcode() == HloOpcode::kConstant) { llvm::Value* global_for_constant = module_->getGlobalVariable( llvm_ir::ConstantHloToGlobalName(*non_io_hlo)); + CHECK(global_for_constant) + << llvm_ir::ConstantHloToGlobalName(*non_io_hlo); BindHloToIrValue(*non_io_hlo, global_for_constant); } else { llvm::Type* pointee_type = @@ -117,11 +119,11 @@ static bool HasMeaningfulName(llvm::Value* value) { return false; } -llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, - ShapeIndexView shape_index, - llvm::Value* ir_value) { - llvm::Type* pointee_type = llvm_ir::ShapeToIrType( - ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_); +llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value, + llvm::IRBuilder<>* b) { + llvm::Type* pointee_type = + llvm_ir::ShapeToIrType(shape, b->GetInsertBlock()->getModule()); + llvm::Type* dest_type = pointee_type->getPointerTo(); llvm::Value* typed_ir_value; @@ -129,9 +131,17 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast( llvm::cast(ir_value), dest_type); } else { - typed_ir_value = b_->CreatePointerBitCastOrAddrSpaceCast( + typed_ir_value = b->CreatePointerBitCastOrAddrSpaceCast( ir_value, pointee_type->getPointerTo()); } + return typed_ir_value; +} + +llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
+ ShapeIndexView shape_index, + llvm::Value* ir_value) { + auto typed_ir_value = CastToTypedValue( + ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_value, b_); if (!HasMeaningfulName(ir_value)) { ir_value->setName(llvm_ir::IrName(&hlo, "raw")); } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index 5eef6727801..3813ec6c949 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -116,6 +116,10 @@ class HloToIrBindings { llvm::Value* temp_buffer_base_ = nullptr; }; +// Converts `ir_value` with type i8* to a typed LLVM Value* based on `shape`. +llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value, + llvm::IRBuilder<>* b); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc index 6d663c66b50..d11d1659d51 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/util/env_var.h" @@ -137,25 +138,6 @@ bool IsFusionSupported(const HloInstruction& instr) { return true; } -bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, - const HloInstruction& consumer) { - return absl::c_all_of(instr.users(), [&](const HloInstruction* user) { - if (user->opcode() == HloOpcode::kGetTupleElement) { - // Skip GTE. - return IsConsumerTheOnlyNonRootUser(*user, consumer); - } else if (user == &consumer) { - // `user` is `consumer`. - return true; - } else if (user == user->parent()->root_instruction()) { - // Consumed by ROOT is always fine, since it is impossible to create - // cycles through ROOT. - return true; - } else { - return false; - } - }); -} - // Returns whether `instr` is a profitable candidate to be horizontally fused. // Since the primary benefit of horizontal fusion comes from reducing the // kernel launch overhead, we want to exclude the instructions with diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc new file mode 100644 index 00000000000..58ed9f18840 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc @@ -0,0 +1,168 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/core/platform/errors.h" + +namespace xla { +namespace gpu { + +namespace { + +// Gets the representative input shape of the multi-output fusion. +Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) { + // Get the HLO that determines the emitter used for lowering. + const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr); + if (real_hero->operands().empty()) { + // Simply return an empty shape if the representative node has no input + // operands. + return Shape(); + } else { + return real_hero->operand(0)->shape(); + } +} + +class HorizontalInputFusionImpl { + public: + explicit HorizontalInputFusionImpl(HloComputation* computation) + : computation_(computation) {} + + ~HorizontalInputFusionImpl() {} + + StatusOr Run(); + + private: + HloComputation* computation_; +}; // HorizontalInputFusionImpl + +// Compares one-by-one the dimensions of `shape_a` and `shape_b` from left to +// right. +bool CompareShapeDimsFromLeftToRight(const Shape& shape_a, + const Shape& shape_b) { + if (shape_a.rank() != shape_b.rank()) { + return shape_a.rank() < shape_b.rank(); + } + auto dims_a = shape_a.dimensions(); + auto dims_b = shape_b.dimensions(); + for (size_t i = 0; i < dims_a.size(); ++i) { + if (dims_a[i] != dims_b[i]) { + return dims_a[i] < dims_b[i]; + } + } + return true; +} + +std::vector FindAndSortFusionCandidates( + HloInstruction* consumer) { + absl::flat_hash_set fusion_instr_set; + for (auto opnd : consumer->operands()) { + HloInstruction* predecessor = opnd->LatestNonGteAncestor(); + // Find out the input fusion instructions whose only consumer is `consumer`. + // This guarantees that fusing these candidates will never create cycles, as + // there is no back edge. + if (IsReduceInputFusion(*predecessor) && + IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) { + fusion_instr_set.insert(predecessor); + } + } + + std::vector fusion_instrs; + fusion_instrs.insert(fusion_instrs.end(), fusion_instr_set.begin(), + fusion_instr_set.end()); + + std::sort(fusion_instrs.begin(), fusion_instrs.end(), + [&](const HloInstruction* a, const HloInstruction* b) { + Shape shape_a = GetInputShapeForMultiOutputFusion(*a); + Shape shape_b = GetInputShapeForMultiOutputFusion(*b); + if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) { + // Sort shapes according to dimensions, so that the same input + // shapes will be placed adjacent each other. + return CompareShapeDimsFromLeftToRight(shape_a, shape_b); + } + // Sort `fusion_instrs` according to instruction counts, because + // we'd like to fuse together computations of similar sizes. + return a->fused_instruction_count() < + b->fused_instruction_count(); + }); + + return fusion_instrs; +} + +StatusOr HorizontalInputFusionImpl::Run() { + bool changed = false; + XLA_VLOG_LINES(3, computation_->ToString()); + + // Using def-to-use order is sound since we do not modify users. 
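// Sketch of the intended effect, taken from the BasicTest HLO added later in
// this change: two kInput reductions whose only consumer is the ROOT tuple,
//   fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1
//   fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2
//   ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2)
// are merged by the loop below into a single multi-output kInput fusion.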
+ std::vector def_to_use_order = + computation_->MakeInstructionPostOrder(); + for (size_t i = 0; i < def_to_use_order.size(); ++i) { + auto consumer = def_to_use_order[i]; + auto candidates = FindAndSortFusionCandidates(consumer); + if (candidates.empty()) { + continue; + } + + size_t fusion_anchor_id = 0; + for (size_t j = 1; j < candidates.size(); ++j) { + HloInstruction* fusion_anchor = candidates[fusion_anchor_id]; + HloInstruction* fused = candidates[j]; + if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) && + !FusionWouldBeTooLarge(*fusion_anchor, *fused)) { + VLOG(3) << "Fuse " << fused->ToString() << " into " + << fusion_anchor->ToString(); + fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused); + changed = true; + } else { + // Update the `fusion_anchor_id` since `fused` is either not + // compatible or not beneficial to be fused with current fusion anchor. + VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused."; + fusion_anchor_id = j; + } + } + } + + return changed; +} + +} // namespace + +StatusOr GpuHorizontalInputFusion::RunOnComputation( + HloComputation* computation) { + HorizontalInputFusionImpl horizontal_fusion_impl(computation); + return horizontal_fusion_impl.Run(); +} + +StatusOr GpuHorizontalInputFusion::Run(HloModule* module) { + bool changed = false; + VLOG(2) << "Run horizontal input fusion."; + for (auto* comp : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(changed, RunOnComputation(comp)); + } + + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h new file mode 100644 index 00000000000..85313d03412 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { +namespace gpu { + +// This optimization pass horizontally fuses kInput fusions to both reduce the +// kernel launch overhead and increase parallelism degree. See +// GpuHorizontalFusion for general description and motivation about horizontal +// fusion. GpuHorizontalFusion deals with kLoop fusions while this pass deals +// with kInput fusions. +// +// Following GpuHorizontalFusion, a simple yet effective heuristic is used +// to search the fusion candidates while avoiding creating cycles. 
That is, +// we simply search for fusion candidates by looking for instructions whose +// outputs are all consumed by the same instruction. This catches the typical +// target cases; often, the candidate instructions are just consumed by the +// ROOT tuple of the entry computation. +class GpuHorizontalInputFusion : public HloModulePass { + public: + GpuHorizontalInputFusion() {} + + absl::string_view name() const override { + return "gpu_horizontal_input_fusion"; + } + + StatusOr Run(HloModule* module) override; + + private: + StatusOr RunOnComputation(HloComputation*); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc new file mode 100644 index 00000000000..88fdd3ec293 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc @@ -0,0 +1,146 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h" + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class HorizontalInputFusionTest : public GpuCodegenTest {}; + +TEST_F(HorizontalInputFusionTest, BasicTest) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule BasicTest + + %add_f16 { + %x = f16[] parameter(0) + %y = f16[] parameter(1) + ROOT %add = f16[] add(%x, %y) + } + + fused_computation.1 { + arg.1 = f16[1024]{0} parameter(0) + constant0 = f16[] constant(0) + ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16 + } + + fused_computation.2 { + arg.1 = f16[1024]{0} parameter(0) + constant0 = f16[] constant(0) + ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16 + } + + ENTRY entry_computation { + arg.1 = f16[1024]{0} parameter(0) + arg.2 = f16[1024]{0} parameter(1) + fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1 + fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2 + ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2) + } +)") + .ValueOrDie(); + + EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie()); + + const HloInstruction* entry_root = + module->entry_computation()->root_instruction(); + EXPECT_THAT(entry_root, op::Tuple((op::GetTupleElement(op::Fusion())), + 
(op::GetTupleElement(op::Fusion())))); + + const HloInstruction* fusion = entry_root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(HorizontalInputFusionTest, ManyInputFusions) { + auto module = CreateNewVerifiedModule(); + + HloComputation* reduce_computation; + { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + reduce_computation = + module->AddEmbeddedComputation(embedded_builder.Build()); + } + + HloComputation::Builder builder(TestName()); + std::vector var_outs; + auto input_shape = ShapeUtil::MakeShape(F32, {1024, 1024}); + auto output_shape = ShapeUtil::MakeShape(F32, {1024}); + for (int64 i = 0; i < 130; ++i) { + // %fused_computation.3 (param_0: f32[1024,1024], param_1: f32[]) -> + // f32[1024] { + // %param_0 = f32[1024,1024]{1,0} parameter(0) + // %param_1 = f32[] parameter(1) + // %broadcast = f32[1024,1024]{1,0} broadcast(f32[] %param_1), + // dimensions={} + // %multiply = f32[1024,1024]{1,0} + // multiply(f32[1024,1024]{1,0} %param_0, f32[1024,1024]{1,0} + // %broadcast) + // %constant0 = f32[] constant(0) + // ROOT %reduce = f32[1024]{0} + // reduce(f32[1024,1024]{1,0} %multiply, f32[] %constant0), + // dimensions={1}, to_apply=%add + // } + HloInstruction* param_var_in = builder.AddInstruction( + HloInstruction::CreateParameter(i * 2 + 0, input_shape, "var.in")); + HloInstruction* param_alpha = + builder.AddInstruction(HloInstruction::CreateParameter( + i * 2 + 1, ShapeUtil::MakeShape(F32, {}), "alpha")); + auto alpha_broadcasted = builder.AddInstruction( + HloInstruction::CreateBroadcast(input_shape, param_alpha, {})); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + input_shape, HloOpcode::kMultiply, param_var_in, alpha_broadcasted)); + HloInstruction* const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + output_shape, mul, const0, {1}, reduce_computation)); + var_outs.push_back(reduce); + } + builder.AddInstruction(HloInstruction::CreateTuple(var_outs)); + module->AddEntryComputation(builder.Build()); + + // Verify that horizontal fusion is kicked in. Check that there are multiple + // `reduce` instructions fused into the same fusion. 6 is just a randomly + // picked number as we don't exactly know how large the fusion will be + // created due to the `FusionWouldBeTooLarge` constraint. + CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)", + /*match_optimized_ir=*/false); + + // Testing with the entire gpu optimization pipeline. 
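// RunAndCompare (inherited from HloTestBase) executes the module on the test
// backend and on the reference backend with generated inputs and checks that
// the outputs agree within the absolute/relative tolerances of the ErrorSpec
// used below.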
+ EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index b994ead17ca..b90e4d85f80 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -60,18 +60,22 @@ bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer, // Output fusions are not currently supported on GPUs. if (producer->opcode() == HloOpcode::kFusion) { + VLOG(4) << "Producer " << producer->name() << " is a fusion op"; return false; } // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). - if (producer->opcode() != HloOpcode::kFusion && - consumer->ReusesOperandElements(operand_index) && - is_expensive(*producer)) { + if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) && + ReusesOperandElements(consumer, operand_index)) { + VLOG(4) << "Do not fuse simple, expensive producer " << producer->name() + << " and consumer which reuses operand elements."; return false; } if (!IsProducerConsumerFusible(*producer, *consumer) || !InstructionFusion::ShouldFuse(consumer, operand_index)) { + VLOG(4) << "Producer " << producer->name() + << " is not fusible or should not be fused."; return false; } return true; @@ -87,7 +91,8 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, auto producer = consumer->operand(operand_index); // The following checks are potentially expensive. - if (FusionWouldBeTooLarge(*consumer, *producer)) { + if (FusionWouldBeTooLarge(*consumer, *producer, + /*is_consumer_producer_fusion=*/true)) { VLOG(5) << "Fusion of (" << producer->ToString() << ") into (" << consumer->ToString() << ") would be too large"; return false; @@ -107,8 +112,12 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, fusion_node_evaluations_.emplace(consumer, FusionNodeIndexingEvaluation(consumer)); } - return !fusion_node_evaluations_.at(consumer).AverageCodeDuplicationTooHigh( - producer); + if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(producer)) { + VLOG(5) << "Fusion of " << producer->name() << " into " << consumer->name() + << " would result in overly large code duplication."; + return false; + } + return true; } bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, diff --git a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc index 4dbd3196ae6..154612824ef 100644 --- a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc +++ b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc @@ -28,7 +28,7 @@ namespace mlir { namespace xla_thunks { XLAThunksDialect::XLAThunksDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { + : Dialect(getDialectNamespace(), context, TypeID::get()) { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc.inc" diff --git a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.td b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.td index 38602550864..eb203e6917d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.td +++ b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.td @@ -21,12 +21,6 @@ limitations under the License. 
include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/IR/OpBase.td" -class LLVMPointerTo - : ContainerType().isPointerTy()">, - "$_self.cast<::mlir::LLVM::LLVMType>().getPointerElementTy()", - "LLVM pointer">; - def XLAThunks_Dialect : Dialect { let name = "xla_thunks"; let cppNamespace = "xla_thunks"; @@ -45,12 +39,12 @@ def AllocationSlice : StructAttr<"AllocationSlice", XLAThunks_Dialect, [ def MemzeroThunkOp : ThunkOp<"execute_memzero_thunk"> { let arguments = (ins - LLVMPointerTo>:$execute_params, + LLVM_PointerTo:$execute_params, AllocationSlice:$allocation_slice ); let results = (outs I<1>:$ok, - LLVMPointerTo>:$error_message + LLVM_PointerTo:$error_message ); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 6309d7fcdee..9d4ec358bd3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -433,7 +433,7 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, builder->CreateZExt( builder->CreateBitCast(value, builder->getIntNTy(bit_width)), builder->getIntNTy(32 * num_segments)), - llvm::VectorType::get(builder->getInt32Ty(), num_segments)); + llvm::VectorType::get(builder->getInt32Ty(), num_segments, false)); for (int i = 0; i < num_segments; ++i) { llvm::Value* insert_val; if (target_triple.isNVPTX()) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 31203b9c5f0..2215881271c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -30,12 +30,14 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" @@ -98,6 +100,64 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { .MakeElementGenerator(hlo, operand_to_generator)); } +Status IrEmitter::EmitConstants(const HloComputation& computation, + bool lookup_indices) { + for (HloInstruction* instr : computation.instructions()) { + if (instr->opcode() != HloOpcode::kConstant) { + continue; + } + Literal& literal = *Cast(instr)->mutable_literal(); + const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal); + llvm::ArrayType* global_type = + llvm::ArrayType::get(b_.getInt8Ty(), literal.size_bytes()); + llvm::Constant* initializer = + should_emit_initializer + ? 
llvm_ir::ConvertLiteralToIrConstant(literal, module_) + : llvm::ConstantAggregateZero::get(global_type); + if (should_emit_initializer) { + VLOG(3) << "Emitted initializer for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + } + + // These globals will be looked up by name by GpuExecutable so we need to + // give them an external linkage. Not all of their uses are visible in + // the LLVM IR (e.g. TupleThunk) so we can't give them a linkage that + // merely preserves their names (like available_externally); we also need + // to ensure that they stick around even if they're "unused". + // + // We may have to be more clever here in the future if we notice that + // we're keeping around too many globals because of their linkage. + unsigned global_address_space = llvm_ir::GetGlobalMemoryAddressSpace( + *ir_emitter_context_->llvm_module()); + + std::string global_name = llvm_ir::ConstantHloToGlobalName(*instr); + + llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( + global_type, /*isConstant=*/should_emit_initializer, + llvm::GlobalValue::ExternalLinkage, + /*Initializer=*/initializer, global_name, + /*TLMode=*/llvm::GlobalValue::NotThreadLocal, + /*AddressSpace=*/global_address_space, + /*isExternallyInitialized=*/false); + global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes)); + ir_emitter_context_->llvm_module()->getGlobalList().push_back( + global_for_const); + + GpuExecutable::ConstantInfo info; + info.symbol_name = global_name; + info.content = literal.Clone(); + if (lookup_indices) { + auto maybe_slice = + ir_emitter_context_->buffer_assignment().GetUniqueSlice(instr, {}); + if (maybe_slice.ok()) { + info.allocation_index = maybe_slice.ValueOrDie().index(); + } + } + ir_emitter_context_->constants().push_back(std::move(info)); + } + return Status::OK(); +} + Status IrEmitter::HandleConstant(HloInstruction* constant) { return Status::OK(); } @@ -175,10 +235,12 @@ Status IrEmitter::EmitCallToNestedComputation( llvm::Function*& emitted_function = computation_to_ir_function_[&nested_computation]; if (emitted_function == nullptr) { - IrEmitterNested ir_emitter_nested(hlo_module_config_, nested_computation, - ir_emitter_context_); - TF_RETURN_IF_ERROR(ir_emitter_nested.CodegenNestedComputation()); - emitted_function = ir_emitter_nested.GetEmittedFunction(); + TF_ASSIGN_OR_RETURN( + auto ir_emitter_nested, + IrEmitterNested::Create(hlo_module_config_, nested_computation, + ir_emitter_context_)); + TF_RETURN_IF_ERROR(ir_emitter_nested->CodegenNestedComputation()); + emitted_function = ir_emitter_nested->GetEmittedFunction(); } // Operands are in default address space for non-AMDGPU target. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 50e9f06ef08..1a387528220 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -105,6 +105,12 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::IRBuilder<>* builder() { return &b_; } + // Emits constants to the generated LLVM IR, and also populates related + // information to ir_emitter_context for large-constant initializations. If + // `lookup_indices` is true, the allocation index associated with the constant + // is also populated. + Status EmitConstants(const HloComputation& computation, bool lookup_indices); + protected: // Constructs an IrEmitter with the given IrEmitter context.
// ir_emitter_context is owned by the caller and should outlive the IrEmitter diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h index 9c43f80dc60..34b93ca5b3f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h @@ -17,13 +17,19 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_ #include "llvm/IR/Module.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" namespace xla { namespace gpu { + // IrEmitterContext encapsulates common (mutable and immutable) data structures // used by both IrEmitterNested and IrEmitterUnnested, such as the buffer // assignment and the name uniquer. @@ -34,14 +40,20 @@ class IrEmitterContext { const HloModule* hlo_module, const BufferAssignment* buffer_assignment, std::string platform_name, GpuDeviceInfo gpu_device_info, absl::optional cuda_compute_capability, - const HloProfileIndexMap* profile_index_map, llvm::Module* llvm_module) + const HloProfileIndexMap* profile_index_map, + mlir::MLIRContext* mlir_context, llvm::Module* llvm_module) : hlo_module_(hlo_module), buffer_assignment_(buffer_assignment), platform_name_(std::move(platform_name)), gpu_device_info_(gpu_device_info), cuda_compute_capability_(cuda_compute_capability), profile_index_map_(profile_index_map), - llvm_module_(llvm_module) {} + mlir_context_(mlir_context), + llvm_module_(llvm_module) { + mlir_context_ + ->loadDialect(); + } // Disallow copy and assign. 
IrEmitterContext(const IrEmitterContext&) = delete; IrEmitterContext& operator=(const IrEmitterContext&) = delete; @@ -57,9 +69,12 @@ class IrEmitterContext { return cuda_compute_capability_; } const HloProfileIndexMap* profile_index_map() { return profile_index_map_; } + mlir::MLIRContext* mlir_context() { return mlir_context_; } llvm::Module* llvm_module() { return llvm_module_; } NameUniquer* name_uniquer() { return &name_uniquer_; } + std::vector& constants() { return constants_; } + private: const HloModule* hlo_module_; const BufferAssignment* buffer_assignment_; @@ -67,8 +82,10 @@ class IrEmitterContext { GpuDeviceInfo gpu_device_info_; absl::optional cuda_compute_capability_; const HloProfileIndexMap* profile_index_map_; + mlir::MLIRContext* mlir_context_; llvm::Module* llvm_module_; NameUniquer name_uniquer_; + std::vector constants_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index e96c5f05e60..5fc091ed8e7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -41,6 +41,16 @@ IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config, : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true), nested_computation_(nested_computation) {} +StatusOr> IrEmitterNested::Create( + const HloModuleConfig& hlo_module_config, + const HloComputation& nested_computation, + IrEmitterContext* ir_emitter_context) { + std::unique_ptr emitter(new IrEmitterNested( + hlo_module_config, nested_computation, ir_emitter_context)); + TF_RETURN_IF_ERROR(emitter->EmitConstants(nested_computation, false)); + return emitter; +} + // Nested function serves the same purpose on GPU as a thread-local function on // a CPU. Status IrEmitterNested::CodegenNestedComputation() { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h index ce825851bcc..8ed76cabcda 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h @@ -39,12 +39,11 @@ namespace gpu { // class IrEmitterNested : public IrEmitter { public: - // Constructs an LLVM IR emitter for a nested HLO computation. `function` is - // the containing IR function this emitter produces IR to. See - // IrEmitter::IrEmitter for the meanings of other arguments. - IrEmitterNested(const HloModuleConfig& hlo_module_config, - const HloComputation& nested_computation, - IrEmitterContext* ir_emitter_context); + static StatusOr> Create( + const HloModuleConfig& hlo_module_config, + const HloComputation& nested_computation, + IrEmitterContext* ir_emitter_context); + IrEmitterNested(const IrEmitterNested&) = delete; IrEmitterNested& operator=(const IrEmitterNested&) = delete; @@ -62,6 +61,13 @@ class IrEmitterNested : public IrEmitter { Status CodegenNestedComputation(); private: + // Constructs an LLVM IR emitter for a nested HLO computation. `function` is + // the containing IR function this emitter produces IR to. See + // IrEmitter::IrEmitter for the meanings of other arguments. 
+ IrEmitterNested(const HloModuleConfig& hlo_module_config, + const HloComputation& nested_computation, + IrEmitterContext* ir_emitter_context); + const HloComputation& nested_computation_; llvm::Function* emitted_function_; }; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 34cdfb4ecf0..f7627c348b6 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "llvm/ADT/StringRef.h" @@ -36,6 +37,13 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/mlir/xla/hlo_utils.h" +#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" +#include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" @@ -82,6 +90,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -133,7 +142,7 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::LLVMContext& llvm_context = llvm_module->getContext(); llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get( llvm::IntegerType::get(llvm_context, /*NumBits=*/32), - launch_dims.threads_per_block()); + launch_dims.thread_counts_per_block().x); // Our launch bounds are exact, so we can specify them as reqntidx rather than // maxntidx. 
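The IrEmitterNested/IrEmitterUnnested changes in this hunk make the constructors private and add static Create() factories returning StatusOr, because the extra initialization (EmitConstants, LHLO scratch-emitter setup) can fail and a constructor has no way to report that. A minimal sketch of the pattern, using a simplified variant-based StatusOr stand-in rather than the real xla::StatusOr.

#include <iostream>
#include <memory>
#include <string>
#include <variant>

// Simplified stand-in for StatusOr<T>: either a value or an error string.
template <typename T>
using StatusOr = std::variant<T, std::string>;

class Emitter {
 public:
  // Factory: runs fallible initialization before handing the object out.
  static StatusOr<std::unique_ptr<Emitter>> Create(bool emit_constants_ok) {
    std::unique_ptr<Emitter> emitter(new Emitter());
    if (!emit_constants_ok) {
      return std::string("EmitConstants failed");
    }
    return std::move(emitter);
  }

 private:
  Emitter() = default;  // construction only via Create()
};

int main() {
  auto result = Emitter::Create(/*emit_constants_ok=*/true);
  if (std::holds_alternative<std::string>(result)) {
    std::cerr << "error: " << std::get<std::string>(result) << "\n";
    return 1;
  }
  std::unique_ptr<Emitter> emitter =
      std::move(std::get<std::unique_ptr<Emitter>>(result));
  std::cout << "emitter created\n";
}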
nvvm_annotations_node->addOperand(llvm::MDNode::get( @@ -143,13 +152,85 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); } +int64_t GetAllocationIndex(mlir::BlockArgument func_arg) { + auto func_op = + mlir::cast(func_arg.getParentRegion()->getParentOp()); + return func_op + .getArgAttrOfType(func_arg.getArgNumber(), + "lmhlo.alloc") + .getValue() + .getSExtValue(); +} + +StatusOr GetAllocationSliceForMlir( + mlir::Value v, absl::Span allocations) { + int64 size = v.getType().cast().getSizeInBits() / 8; + + if (auto arg = v.dyn_cast()) { + return BufferAllocation::Slice(&allocations[GetAllocationIndex(arg)], 0, + size); + } + + // We match two patterns here: + // * v = ViewOp(arg); + // * v = StaticMemRefCastOp(ViewOp(arg)); + if (mlir::Operation* op = v.getDefiningOp()) { + if (auto cast = mlir::dyn_cast(op)) { + mlir::Value source = cast.getViewSource(); + op = source.getDefiningOp(); + if (!op) { + return Unimplemented("StaticMemRefCastOp has to wrap an op"); + } + } + if (auto view = mlir::dyn_cast(op)) { + return BufferAllocation::Slice( + &allocations[GetAllocationIndex( + view.source().cast())], + mlir::cast(view.byte_shift().getDefiningOp()) + .value() + .cast() + .getValue() + .getSExtValue(), + size); + } + return Unimplemented("StaticMemRefCastOp has to wrap a ViewOp"); + } + + return Unimplemented( + "Operand has to be in the form of ViewOp(arg) or " + "StaticMemRefCastOp(ViewOp(arg))"); +} + +absl::string_view GetHloName(mlir::Operation* op) { + if (auto attr = op->getAttrOfType("name")) { + auto ref = attr.getValue(); + return absl::string_view(ref.data(), ref.size()); + } + return ""; +} + } // namespace IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, const HloComputation* hlo_computation, IrEmitterContext* ir_emitter_context) : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false), - hlo_computation_(hlo_computation) {} + hlo_computation_(hlo_computation), + mlir_scratch_module_(mlir::ModuleOp::create( + mlir::Builder(ir_emitter_context->mlir_context()).getUnknownLoc())), + lhlo_scratch_emitter_(ir_emitter_context_->buffer_assignment(), + *hlo_computation, mlir_scratch_module_.get()) {} + +StatusOr> IrEmitterUnnested::Create( + const HloModuleConfig& hlo_module_config, + const HloComputation* hlo_computation, + IrEmitterContext* ir_emitter_context) { + auto emitter = std::unique_ptr(new IrEmitterUnnested( + hlo_module_config, hlo_computation, ir_emitter_context)); + TF_RETURN_IF_ERROR(emitter->lhlo_scratch_emitter_.Initialize()); + TF_RETURN_IF_ERROR(emitter->EmitConstants(*hlo_computation, true)); + return std::move(emitter); +} Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { bindings_.UnbindAllLocalIrValues(); @@ -157,12 +238,11 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { } llvm::Function* IrEmitterUnnested::BuildKernelPrototype( - const HloInstruction& inst, - absl::Span args) { + absl::string_view name, absl::Span args) { // Compute the kernel name. The opcode string may contain "-" which cannot be // in a PTX function name, so sanitize the name before uniquifying it. string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName( - llvm_ir::SanitizeFunctionName(inst.name())); + llvm_ir::SanitizeFunctionName(std::string(name))); // Create the kernel and add it to the module. 
llvm::Module* module = ir_emitter_context_->llvm_module(); @@ -358,7 +438,8 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { } Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { - AddThunkToThunkSequence(BuildConditionalThunk(conditional)); + TF_ASSIGN_OR_RETURN(auto thunk, BuildConditionalThunk(conditional)); + AddThunkToThunkSequence(std::move(thunk)); return Status::OK(); } @@ -1037,10 +1118,13 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { // Build ForThunk for conformant while loops, otherwise build WhileThunk. auto config = xla_while->backend_config(); if (config.ok() && config.ValueOrDie().has_known_trip_count()) { - AddThunkToThunkSequence( + TF_ASSIGN_OR_RETURN( + auto thunk, BuildForThunk(xla_while, config.ValueOrDie().known_trip_count().n())); + AddThunkToThunkSequence(std::move(thunk)); } else { - AddThunkToThunkSequence(BuildWhileThunk(xla_while)); + TF_ASSIGN_OR_RETURN(auto thunk, BuildWhileThunk(xla_while)); + AddThunkToThunkSequence(std::move(thunk)); } return Status::OK(); } @@ -1263,37 +1347,110 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { return IrEmitter::HandleSelect(select); } +StatusOr +IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region) { + std::unique_ptr& module = scratch_nested_computations_[region]; + if (module == nullptr) { + xla::XlaComputation xla_computation; + TF_RETURN_IF_ERROR(ConvertRegionToComputation(region, &xla_computation)); + TF_ASSIGN_OR_RETURN(auto program_shape, xla_computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN( + module, HloModule::CreateFromProto(xla_computation.proto(), + HloModuleConfig(program_shape))); + } + return module->entry_computation(); +} + Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { + MlirEmitterInput result; + + TF_ASSIGN_OR_RETURN(auto sort_op, lhlo_scratch_emitter_.EmitSortOp(sort)); + result.op = sort_op; + result.name = GetHloName(sort_op); + // The name in sort op has no semantics, and it's for debug only. If the name + // doesn't exist, we should use a namer (e.g. count-based). + // TODO(timshen): use a namer instead of relying on the HloInstruction names. + if (result.name.empty()) { + result.name = sort->name(); + } + const auto& buffer_assignment = ir_emitter_context_->buffer_assignment(); + auto& slice = result.extra_slice; + TF_ASSIGN_OR_RETURN(slice.buffer_slice, + buffer_assignment.GetUniqueSlice(sort, {})); + slice.written = true; + slice.shape = sort->shape(); + + result.thunk_info = GetThunkInfo(sort); + + return EmitSortFromMlir(result); +} + +Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) { + absl::Span allocations( + ir_emitter_context_->buffer_assignment().Allocations()); + auto sort_op = mlir::cast(input.op); + + int operand_count = sort_op.operands().size(); + std::vector operand_shapes(operand_count); + std::vector slices; + std::vector output_shapes(sort_op.output().size()); + + for (int i = 0; i < operand_count; i++) { + operand_shapes[i] = + TypeToShape(sort_op.operands()[i].getType().cast()); + } + + // Craft n + 1 slices, where the first n are output parameters, and the last + // is the on-device tuple storage. We don't need n operands because sorting + // kernels are always in-place. 
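For context on the sort lowering in this hunk: a little further down, the emitter derives the number of bitonic sorting stages from Log2Ceiling of the sort-dimension bound and checks that 2^num_stages covers the bound. A quick standalone check of that arithmetic; Log2Ceiling is reimplemented here for illustration in place of tensorflow::Log2Ceiling.

#include <cassert>
#include <cstdint>
#include <iostream>

// Smallest n such that 2^n >= x (for x >= 1), mirroring tensorflow::Log2Ceiling.
int64_t Log2Ceiling(uint64_t x) {
  int64_t n = 0;
  while ((uint64_t{1} << n) < x) ++n;
  return n;
}

int main() {
  for (uint64_t dimension_to_sort_bound : {1u, 2u, 3u, 1000u, 1024u, 1025u}) {
    int64_t num_stages = Log2Ceiling(dimension_to_sort_bound);
    // The invariants the emitter asserts before building the xor-mask passes.
    assert((uint64_t{1} << num_stages) >= dimension_to_sort_bound);
    assert(num_stages == 0 ||
           (uint64_t{1} << (num_stages - 1)) < dimension_to_sort_bound);
    std::cout << "bound=" << dimension_to_sort_bound
              << " -> num_stages=" << num_stages << "\n";
  }
}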
+ for (int i = 0; i < operand_count; i++) { + output_shapes[i] = + TypeToShape(sort_op.output()[i].getType().cast()); + MlirBufferSlice slice; + TF_ASSIGN_OR_RETURN( + slice.buffer_slice, + GetAllocationSliceForMlir(sort_op.output()[i], allocations)); + slice.written = true; + slice.shape = operand_shapes[i]; + slices.push_back(slice); + } + slices.push_back(input.extra_slice); + std::vector> thunks; - Shape keys_shape = sort->operand(0)->shape(); - int64 dimension_to_sort = sort->dimensions(0); - for (int64 i = 0; i < sort->operand_count(); ++i) { - ShapeIndex shape_index = - sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); + + Shape keys_shape = operand_shapes[0]; + int64 dimension_to_sort = sort_op.dimension(); + for (int64 i = 0; i < operand_count; ++i) { // We assume that the layout of all involved operands and outputs is the // same. - TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape, - sort->operand(i)->shape())); - TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( - keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index))); + TF_RET_CHECK( + LayoutUtil::LayoutsInShapesEqual(keys_shape, operand_shapes[i])); + TF_RET_CHECK( + LayoutUtil::LayoutsInShapesEqual(keys_shape, output_shapes[i])); // If possible, we share buffers. If that is not possible, we need to copy // the values, because the emitter does the sorting in-place. - auto destination_buffer = GetAllocationSlice(*sort, shape_index); - auto source_address = GetAllocationSlice(*sort->operand(i)); + TF_ASSIGN_OR_RETURN( + auto destination_buffer, + GetAllocationSliceForMlir(sort_op.output()[i], allocations)); + TF_ASSIGN_OR_RETURN( + auto source_address, + GetAllocationSliceForMlir(sort_op.operands()[i], allocations)); if (destination_buffer != source_address) { // TODO(b/26783907): Figure out why we never seem to share buffers for // key/value sort. + VLOG(2) << input.name << " requires initial D2D copy for operand " << i; thunks.push_back(absl::make_unique( Thunk::ThunkInfo(), /*source_address=*/source_address, /*destination_buffer=*/destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()))); + /*mem_size=*/ShapeUtil::ByteSizeOf(operand_shapes[i]))); } } uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); + VLOG(2) << input.name << " requires " << num_stages << " stages."; CHECK_GE(1ULL << num_stages, dimension_to_sort_bound); CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound); @@ -1357,10 +1514,10 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { // we have not enough threads, or not enough shared memory. Also it does not // give a speedup if the tile size is < 128. int64 total_shared_memory_needed = 0; - for (int64 i = 0; i < sort->operand_count(); ++i) { + for (int64 i = 0; i < operand_count; ++i) { total_shared_memory_needed += - kTileSize * ShapeUtil::ByteSizeOfPrimitiveType( - sort->operand(i)->shape().element_type()); + kTileSize * + ShapeUtil::ByteSizeOfPrimitiveType(operand_shapes[i].element_type()); } bool no_tiling = kTileSize < 128 || @@ -1368,34 +1525,51 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { ir_emitter_context_->gpu_device_info().threads_per_block_limit || total_shared_memory_needed > ir_emitter_context_->gpu_device_info().shared_memory_per_block; + VLOG(2) << absl::StreamFormat( + "%s %s use tiling. 
No tiling if any of the following is true: " + "kTileSize=%d < 128, " + "kThreadsPerBlock=%d > threads_per_block_limit=%d, " + "total_shared_memory_needed=%d > shared_memory_per_block=%d", + input.name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock, + ir_emitter_context_->gpu_device_info().threads_per_block_limit, + total_shared_memory_needed, + ir_emitter_context_->gpu_device_info().shared_memory_per_block); uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock); LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock); + VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block", + input.name, num_blocks, kThreadsPerBlock); + std::vector ir_arrays; auto emit_kernel = [&](absl::Span xor_masks) { - thunks.push_back( - BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); + VLOG(2) << absl::StreamFormat( + "%s uses kernel for xor masks [%s]", input.name, + absl::StrJoin(xor_masks, ", ", [](std::string* out, int64 xor_mask) { + absl::StrAppendFormat(out, "0x%x", xor_mask); + })); + thunks.push_back(BuildKernelThunkForMlir(input.name, Thunk::ThunkInfo(), + slices, &ir_arrays)); LaunchDimensions launch_dimensions = xor_masks.size() > 1 ? tiled_launch_dimensions : standard_launch_dimensions; UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), ir_emitter_context_->llvm_module()); std::vector values_arrays; - values_arrays.reserve(sort->operand_count()); - for (int64 i = 0; i < sort->operand_count(); ++i) { - ShapeIndex shape_index = - sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); - values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); + values_arrays.reserve(operand_count); + for (int64 i = 0; i < operand_count; ++i) { + values_arrays.push_back(ir_arrays[i]); } + TF_ASSIGN_OR_RETURN( + const HloComputation* comparator, + GetOrCreateSubComputationFromRegion(&sort_op.comparator())); return llvm_ir::EmitSortInPlace( - dimension_to_sort, values_arrays, IrName(sort), xor_masks, &b_, + dimension_to_sort, values_arrays, IrName(input.name), xor_masks, &b_, launch_dimensions, xor_masks.size() > 1 ? num_iterations_in_sort_dim : standard_num_iterations_in_sort_dim, kTileSize, [&](absl::Span operands, llvm::Value* output) { - return EmitCallToNestedComputation(*sort->to_apply(), operands, - output); + return EmitCallToNestedComputation(*comparator, operands, output); }); }; std::vector xor_masks; @@ -1421,15 +1595,19 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { if (!xor_masks.empty()) { TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); } + VLOG(2) << absl::StreamFormat( + "%s requires %d thunks (including any D2D copies)", input.name, + thunks.size()); - AddThunkToThunkSequence(absl::make_unique( - GetThunkInfo(sort), std::move(thunks))); - if (sort->operand_count() > 1) { + AddThunkToThunkSequence( + absl::make_unique(input.thunk_info, std::move(thunks))); + if (operand_count > 1) { // Emit the tuple as part of the last stage of sorting. // We are currently in the block sorted.in_bounds.after. b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(*sort, *sort), - ConstructIrArrayForOutputs(*sort), &b_); + llvm_ir::EmitTuple( + ir_arrays[operand_count], + absl::MakeSpan(ir_arrays).subspan(0, ir_arrays.size() - 1), &b_); } return Status::OK(); } @@ -1567,24 +1745,6 @@ Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) { return Status::OK(); } -// Describes how to access a particular subshape for an HLO. 
For instance if -// `.hlo_index` is {1} and `.gte_index` is {3, 4} then buffer for `.instr` at -// ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo) is found -// at `.buffer_slice`[3][4]. That is, `.slice` is a void***, which we -// dereference twice -- first at index 3, and then at index 4 -- to get the -// address of our buffer. -struct HloBufferSlice { - const HloInstruction* instr; - ShapeIndex hlo_index; - - // The root buffer to look at. - BufferAllocation::Slice buffer_slice; - - // Describes how to dereference starting at that buffer to get to the buffer - // in question. - ShapeIndex gte_index; -}; - // Figures out how to access the buffers for all subshapes of hlo's operands and // for hlo itself (i.e. all the buffers produced by HLO). // @@ -1693,22 +1853,22 @@ static std::vector GetHloBufferSlices( return result; } -std::unique_ptr IrEmitterUnnested::BuildKernelThunk( - const HloInstruction* inst, bool implements_whole_instruction) { - const BufferAssignment& buffer_assn = - ir_emitter_context_->buffer_assignment(); - - std::vector hlo_slices = - GetHloBufferSlices(inst, buffer_assn); +std::unique_ptr +IrEmitterUnnested::BuildKernelThunkFromBufferSlices( + absl::string_view name, Thunk::ThunkInfo thunk_info, + absl::Span slices, + std::function + bind_slice_to_ir_value) { + const auto& buffer_assn = ir_emitter_context_->buffer_assignment(); // Figure out which buffer allocations need to be passed as arguments to our - // kernel. This is simply all of the allocations referenced in hlo_slices, + // kernel. This is simply all of the allocations referenced in slices, // plus the XLA temp buffer (if we have it). We always include the temp // buffer because even if the kernel itself doesn't use it, a nested // subcomputation within the kernel (e.g. a kMap's computation) might. std::unordered_set buffers_needed; - for (const auto& hlo_buffer_slice : hlo_slices) { - buffers_needed.insert(hlo_buffer_slice.buffer_slice.allocation()); + for (auto* slice : slices) { + buffers_needed.insert(slice->buffer_slice.allocation()); } absl::optional temp_buffer; for (const BufferAllocation& alloc : buffer_assn.Allocations()) { @@ -1737,7 +1897,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( return a->index() < b->index(); }); - llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers); + llvm::Function* kernel = BuildKernelPrototype(name, non_constant_buffers); // Build a map from a BufferAllocation to the corresponding argument in our // kernel. @@ -1771,24 +1931,19 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( // For each buffer our kernel might want to touch, bind it to a value derived // from our kernel args. 
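BuildKernelThunkFromBufferSlices above factors the thunk construction so that the HLO path and the new MLIR path differ only in how a resolved buffer address is bound: each caller supplies a bind_slice_to_ir_value callback. A minimal standalone sketch of that callback seam, with hypothetical Slice/Value types standing in for BufferSlice and llvm::Value*.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Slice {
  std::string label;
};
using Value = std::string;  // pretend an "IR value" is just its name

// Shared builder: resolves an address for every slice, then lets the caller
// decide what to do with the (slice, value) pair.
void BuildKernel(const std::vector<Slice>& slices,
                 const std::function<void(const Slice&, Value)>& bind) {
  for (const Slice& s : slices) {
    Value v = "gep(" + s.label + ")";  // stand-in for the real GEP/load logic
    bind(s, v);
  }
}

int main() {
  std::vector<Slice> slices = {{"operand0"}, {"output0"}};

  // "HLO" caller: binds values into an instruction-to-value map (printed here).
  BuildKernel(slices, [](const Slice& s, Value v) {
    std::cout << "bind HLO buffer " << s.label << " -> " << v << "\n";
  });

  // "MLIR" caller: wraps the same values into typed arrays instead.
  std::vector<Value> ir_arrays;
  BuildKernel(slices, [&](const Slice& s, Value v) {
    ir_arrays.push_back(s.label + ": typed(" + v + ")");
  });
  std::cout << "collected " << ir_arrays.size() << " ir arrays\n";
}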
- for (const auto& hlo_buffer_slice : hlo_slices) { - const HloInstruction* instr = hlo_buffer_slice.instr; - const ShapeIndex& index = hlo_buffer_slice.hlo_index; - const BufferAllocation::Slice& slice = hlo_buffer_slice.buffer_slice; - const ShapeIndex& gte_index = hlo_buffer_slice.gte_index; - - VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString() - << " is found in slice " << slice.ToString() << " at GTE index " - << gte_index.ToString(); + for (auto* slice : slices) { + const BufferAllocation::Slice& buffer_slice = slice->buffer_slice; + const ShapeIndex& gte_index = slice->gte_index; llvm::Value* loc; - if (slice.allocation()->is_constant()) { + if (buffer_slice.allocation()->is_constant()) { loc = ir_emitter_context_->llvm_module()->getGlobalVariable( - llvm_ir::ConstantBufferAllocationToGlobalName(*slice.allocation())); + llvm_ir::ConstantBufferAllocationToGlobalName( + *buffer_slice.allocation())); CHECK_NE(loc, nullptr); } else { - loc = InBoundsGEP(kernel_args.at(slice.allocation()), - {b_.getInt64(slice.offset())}); + loc = InBoundsGEP(kernel_args.at(buffer_slice.allocation()), + {b_.getInt64(buffer_slice.offset())}); } // If gte_index is nonempty, we have to dereference `loc` to get to the @@ -1800,7 +1955,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)})); } - bindings_.BindHloToIrValue(*instr, loc, index); + bind_slice_to_ir_value(slice, loc); } // Bind the temp buffer so that nested subcomputations can find it if they @@ -1812,9 +1967,66 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm::ConstantPointerNull::get(b_.getInt8PtrTy())); } - return absl::make_unique( + return absl::make_unique(thunk_info, non_constant_buffers, + std::string(kernel->getName())); +} + +std::unique_ptr IrEmitterUnnested::BuildKernelThunk( + const HloInstruction* inst, bool implements_whole_instruction) { + std::vector hlo_slices = + GetHloBufferSlices(inst, ir_emitter_context_->buffer_assignment()); + + std::vector slice_ptrs; + slice_ptrs.reserve(hlo_slices.size()); + for (auto& slice : hlo_slices) { + slice_ptrs.push_back(&slice); + } + + return BuildKernelThunkFromBufferSlices( + inst->name(), implements_whole_instruction ? 
GetThunkInfo(inst) : Thunk::ThunkInfo(), - non_constant_buffers, std::string(kernel->getName())); + slice_ptrs, [this](const BufferSlice* slice, llvm::Value* value) { + const HloBufferSlice* hlo_buffer_slice = + static_cast(slice); + const HloInstruction* instr = hlo_buffer_slice->instr; + const ShapeIndex& index = hlo_buffer_slice->hlo_index; + VLOG(3) << "Buffer for " << instr->ToString() << " at " + << index.ToString() << " is found in slice " + << hlo_buffer_slice->buffer_slice.ToString() << " at GTE index " + << hlo_buffer_slice->gte_index.ToString(); + + bindings_.BindHloToIrValue(*instr, value, index); + }); +} + +std::unique_ptr IrEmitterUnnested::BuildKernelThunkForMlir( + absl::string_view name, Thunk::ThunkInfo thunk_info, + absl::Span slices, + std::vector* ir_arrays) { + absl::flat_hash_set buffers_written; + std::vector slice_ptrs; + slice_ptrs.reserve(slices.size()); + for (auto& slice : slices) { + slice_ptrs.push_back(&slice); + if (slice.written) { + buffers_written.insert(slice.buffer_slice); + } + } + + ir_arrays->clear(); + return BuildKernelThunkFromBufferSlices( + name, thunk_info, slice_ptrs, + [&](const BufferSlice* slice, llvm::Value* value) { + const auto& mlir_slice = static_cast(*slice); + + llvm_ir::IrArray ir_array( + CastToTypedValue(mlir_slice.shape, value, &b_), mlir_slice.shape); + if (!buffers_written.contains(slice->buffer_slice)) { + ir_array.MarkInvariantOverWholeProgram(&value->getContext()); + } + + ir_arrays->push_back(ir_array); + }); } StatusOr> IrEmitterUnnested::BuildInitializerThunk( @@ -2021,7 +2233,7 @@ Status CheckConditionalBuffersShareAllocation( } // namespace -std::unique_ptr IrEmitterUnnested::BuildWhileThunk( +StatusOr> IrEmitterUnnested::BuildWhileThunk( const HloInstruction* hlo) { // Check that all while-related buffers share an allocation. TF_CHECK_OK(CheckWhileBuffersShareAllocation( @@ -2029,24 +2241,26 @@ std::unique_ptr IrEmitterUnnested::BuildWhileThunk( // Generate thunk sequence for while 'condition'. HloComputation* condition = hlo->while_condition(); - IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition, - ir_emitter_context_); - TF_CHECK_OK(condition->Accept(&ir_emitter_condition)); + TF_ASSIGN_OR_RETURN(auto ir_emitter_condition, + IrEmitterUnnested::Create(hlo_module_config_, condition, + ir_emitter_context_)); + TF_RETURN_IF_ERROR(condition->Accept(ir_emitter_condition.get())); // Generate thunk sequence for while 'body'. HloComputation* body = hlo->while_body(); - IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, - ir_emitter_context_); - TF_CHECK_OK(body->Accept(&ir_emitter_body)); + TF_ASSIGN_OR_RETURN( + auto ir_emitter_body, + IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_)); + TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get())); - return absl::make_unique( + return std::unique_ptr(new WhileThunk( GetThunkInfo(hlo), GetAllocationSlice(*condition->root_instruction()), // cond result - ir_emitter_condition.ConsumeThunkSequence(), - ir_emitter_body.ConsumeThunkSequence()); + ir_emitter_condition->ConsumeThunkSequence(), + ir_emitter_body->ConsumeThunkSequence())); } -std::unique_ptr IrEmitterUnnested::BuildForThunk( +StatusOr> IrEmitterUnnested::BuildForThunk( const HloInstruction* hlo, const int64 loop_limit) { // Check that all while-related buffers share an allocation. 
TF_CHECK_OK(CheckWhileBuffersShareAllocation( @@ -2054,15 +2268,16 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( // Generate thunk sequence for while 'body' (will be used a For loop body). HloComputation* body = hlo->while_body(); - IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, - ir_emitter_context_); - TF_CHECK_OK(body->Accept(&ir_emitter_body)); + TF_ASSIGN_OR_RETURN( + auto ir_emitter_body, + IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_)); + TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get())); - return absl::make_unique(GetThunkInfo(hlo), loop_limit, - ir_emitter_body.ConsumeThunkSequence()); + return std::unique_ptr(new ForThunk( + GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence())); } -std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( +StatusOr> IrEmitterUnnested::BuildConditionalThunk( const HloInstruction* hlo) { // Check that the buffers used in conditional are shared with the operands and // result appropriately. @@ -2074,15 +2289,17 @@ std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( for (int j = 0; j < hlo->branch_count(); ++j) { branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1))); HloComputation* branch_computation = hlo->branch_computation(j); - IrEmitterUnnested ir_emitter(hlo_module_config_, branch_computation, - ir_emitter_context_); - TF_CHECK_OK(branch_computation->Accept(&ir_emitter)); - branch_thunks.push_back(std::move(*ir_emitter.ConsumeThunkSequence())); + TF_ASSIGN_OR_RETURN( + auto ir_emitter, + IrEmitterUnnested::Create(hlo_module_config_, branch_computation, + ir_emitter_context_)); + TF_CHECK_OK(branch_computation->Accept(ir_emitter.get())); + branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence())); } - return absl::make_unique( + return std::unique_ptr(new ConditionalThunk( GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), branch_operands, - std::move(branch_thunks)); + std::move(branch_thunks))); } Status IrEmitterUnnested::EmitTargetElementLoopInThunk( @@ -2775,6 +2992,28 @@ void IrEmitterUnnested::EmitPrintfWithThreadId( }); } +namespace { + +// Obtains the corresponding index of the out_instr in the outputs of the +// `unnested_hlo`. +ShapeIndex CreateShapeIndexForOutputInstruction( + const HloInstruction& unnested_hlo, const HloInstruction& out_instr) { + if (!unnested_hlo.IsMultiOutputFusion()) { + return ShapeIndex({}); + } + const auto& all_outputs = unnested_hlo.fused_expression_root()->operands(); + for (size_t i = 0; i < all_outputs.size(); ++i) { + if (all_outputs[i] == &out_instr) { + return ShapeIndex({static_cast(i)}); + } + } + LOG(FATAL) << " Fusion root does not contain output instruction; " + << " fusion: " << unnested_hlo.ToString() + << ", output instruction: " << out_instr.ToString(); +} + +} // namespace + void IrEmitterUnnested::EmitTileElementForReduction( HloInstruction* unnested_hlo, const Shape& reduction_operand_shape, absl::Span output_instructions, @@ -2782,7 +3021,6 @@ void IrEmitterUnnested::EmitTileElementForReduction( const ReductionCodegenInfo& reduction_info, absl::Span reducers, int64 x_iter_num) { VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString(); - bool returns_tuple = output_instructions.size() > 1; int partial_result_index = reduction_info.IsRowReduction() ? 
0 : x_iter_num; InlinedVector input_gens; @@ -2799,7 +3037,8 @@ void IrEmitterUnnested::EmitTileElementForReduction( for (int i = 0, e = output_instructions.size(); i != e; ++i) { const HloInstruction* inst = output_instructions[i]; - ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({}); + ShapeIndex idx = + CreateShapeIndexForOutputInstruction(*unnested_hlo, *inst); if (IsReductionFromOrToContiguousDimensions(*inst)) { input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); } else { @@ -3532,71 +3771,41 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo( reduction_dimensions.is_row_reduction); } -Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( +void IrEmitterUnnested::EmitIRForReduction( HloInstruction* unnested_hlo, - absl::Span output_instructions) { - bool returns_tuple = output_instructions.size() > 1; - VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); - + absl::Span output_instructions, + ReductionCodegenInfo* reduction_info, const Shape& input_shape) { std::vector reduce_instructions; InlinedVector reduction_output_shape_indices; InlinedVector reducers; - - // Build an initializer thunk to initialize each reduction output. - std::vector> thunks; - for (int i = 0; i < output_instructions.size(); ++i) { + for (size_t i = 0; i < output_instructions.size(); ++i) { if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { continue; } HloInstruction* output_instruction = output_instructions[i]; reduce_instructions.push_back(output_instruction); - ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({}); - reduction_output_shape_indices.push_back(idx); + reduction_output_shape_indices.push_back( + CreateShapeIndexForOutputInstruction(*unnested_hlo, + *output_instruction)); reducers.push_back(output_instruction->to_apply()); - - TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, - BuildInitializerThunk(unnested_hlo, idx)); - thunks.push_back(std::move(initializer_thunk)); } + CHECK(reduce_instructions.size() != 0) + << " expect at least one reduce instructions."; - const HloInstruction* first_reduce = reduce_instructions.at(0); - if (output_instructions.size() > 1) { - if (!AreFusedReductionOutputsConsistent(output_instructions, - first_reduce)) { - return InternalError("Inconsistent reduction fusion outputs"); - } - } - - // Build a kernel thunk to compute all the outputs. - std::unique_ptr kernel_thunk = - BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false); - - const Shape& input_shape = first_reduce->operand(0)->shape(); - // The layout of a reduction input is either set by LayoutAssignment for - // unnested kReduce or by InstructionFusion for fused kReduce. 
- CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " - "doesn't set the input layout of " - << first_reduce->ToString(); - - ReductionCodegenInfo reduction_info = - ComputeReductionCodegenInfo(unnested_hlo, first_reduce); const KernelMappingScheme& mapping_scheme = - reduction_info.GetKernelMappingScheme(); + reduction_info->GetKernelMappingScheme(); LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(), mapping_scheme.GetThreadsPerBlock()); - VLOG(3) << "Launch dimensions of " << unnested_hlo->name() - << ": number of blocks: " << mapping_scheme.GetNumberOfBlocks() - << " - threads per block: " << mapping_scheme.GetThreadsPerBlock(); llvm::Type* index_ty = GetIndexTypeForKernel( unnested_hlo, launch_dimensions.launch_bound(), &b_); - EmitPrologueForReduction(unnested_hlo, &reduction_info, reduce_instructions, + EmitPrologueForReduction(unnested_hlo, reduction_info, reduce_instructions, index_ty); EmitElementFunction emit_reduction_tile = [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc, int64 x_iter_num) { EmitTileElementForReduction(unnested_hlo, input_shape, - output_instructions, index, reduction_info, + output_instructions, index, *reduction_info, reducers, x_iter_num); }; @@ -3605,70 +3814,185 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index, const string& loop_name, llvm::Value* tile_height, llvm::Value* tile_width, KernelSupportLibrary* ksl) { - EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl, - thread_id_info, tile_height, tile_width, emit_reduction_tile); + EmitTile(reduction_info->GetKernelMappingScheme(), index, loop_name, + ksl, thread_id_info, tile_height, tile_width, + emit_reduction_tile); }); - EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info, + EmitEpilogueForReduction(index_ty, unnested_hlo, *reduction_info, reduce_instructions, reduction_output_shape_indices, reducers, tiling_kernel_info); +} +namespace { + +// Returns whether the `instr` is either a constant, a scalar, or a +// broadcasted constant/scalar. +bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) { + return instr.IsConstant() || ShapeUtil::IsScalar(instr.shape()) || + (HloOpcode::kBroadcast == instr.opcode() && + (instr.operand(0)->IsConstant() || + ShapeUtil::IsScalar(instr.operand(0)->shape()))); +} + +// Divides output_instructions into groups. Different groups will be executed +// in parallel. Generally speaking, we'd like to run the reduce instructions +// in parallel without incurring too much recomputation overhead. The current +// heuristic is to place reduce instructions who share nothing or only +// (broadcasted) scalars/constants into different groups; otherwise, they are +// placed in the same group. Non-reduce instructions always go with the reduce +// instructions into the same group so long as they share any predecessors. 
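A standalone sketch of the grouping heuristic described in the comment above: reduce outputs that only share broadcasted constants/scalars stay in separate groups, while outputs reachable from a common non-trivial predecessor are merged with union-find. It uses a hand-rolled disjoint-set in place of the UnionFind type from union_find.h, and plain strings in place of HLO instructions.

#include <iostream>
#include <map>
#include <numeric>
#include <string>
#include <vector>

// Minimal disjoint-set (union-find) over indices.
struct DisjointSets {
  std::vector<int> parent;
  explicit DisjointSets(int n) : parent(n) {
    std::iota(parent.begin(), parent.end(), 0);
  }
  int Find(int x) { return parent[x] == x ? x : parent[x] = Find(parent[x]); }
  void Merge(int a, int b) { parent[Find(a)] = Find(b); }
};

int main() {
  // Outputs 0 and 1 share the non-trivial predecessor "param0";
  // output 2 only shares a broadcasted constant, so it stays separate.
  std::vector<std::string> outputs = {"reduce.0", "reduce.1", "reduce.2"};
  std::vector<std::vector<int>> predecessor_reaches = {
      /*param0 reaches*/ {0, 1},
      /*broadcasted constant (ignored by the heuristic)*/ {},
  };

  DisjointSets sets(outputs.size());
  for (const std::vector<int>& reached : predecessor_reaches) {
    for (size_t j = 1; j < reached.size(); ++j) {
      sets.Merge(reached[0], reached[j]);
    }
  }

  // Bucket outputs by representative, mirroring the groups map in the patch.
  std::map<int, std::vector<std::string>> groups;
  for (size_t i = 0; i < outputs.size(); ++i) {
    groups[sets.Find(i)].push_back(outputs[i]);
  }
  for (const auto& [rep, members] : groups) {
    std::cout << "group " << rep << ":";
    for (const std::string& m : members) std::cout << " " << m;
    std::cout << "\n";
  }
}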
+std::vector> DivideOutputInstructionsIntoGroups( + HloInstruction* unnested_hlo, + absl::Span output_instructions) { + CHECK(!output_instructions.empty()); + if (output_instructions.size() == 1) { + return {{output_instructions[0]}}; + } + + std::vector> disjoint_sets( + output_instructions.size()); + for (size_t i = 0; i < output_instructions.size(); ++i) { + disjoint_sets[i].Get() = output_instructions[i]; + } + + std::unique_ptr reachability_map = + HloReachabilityMap::Build(unnested_hlo->fused_instructions_computation()); + for (auto* instr : unnested_hlo->fused_instructions()) { + std::vector reached_output_ids; + for (size_t oid = 0; oid < output_instructions.size(); ++oid) { + if (HloOpcode::kReduce == output_instructions[oid]->opcode() && + (IsBroadcastedConstantOrScalar(*instr))) { + // Do not group output reduce instructions through broadcasted + // constants or scalars, as the recomputation should be acceptable. + VLOG(3) << "Skip broadcasted constant or scalar " << instr->ToString(); + continue; + } + // Now group output instructions if they have common predecessors. + if (reachability_map->IsReachable(instr, output_instructions[oid])) { + VLOG(3) << "Reaching " << output_instructions[oid]->ToString() + << " from " << instr->ToString(); + reached_output_ids.push_back(oid); + } + } + for (size_t j = 1; j < reached_output_ids.size(); ++j) { + disjoint_sets[reached_output_ids[0]].Merge( + &disjoint_sets[reached_output_ids[j]]); + } + } + // Place output instructions in the same set into the same group. + absl::flat_hash_map> groups; + for (size_t oid = 0; oid < output_instructions.size(); ++oid) { + groups[disjoint_sets[oid].Get()].push_back(output_instructions.at(oid)); + } + + std::vector> ret; + absl::c_for_each( + groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); }); + return ret; +} + +} // namespace + +Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( + HloInstruction* unnested_hlo, + absl::Span output_instructions) { + bool returns_tuple = output_instructions.size() > 1; + VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); + + // Build an initializer thunk to initialize each reduction output. + std::vector> thunks; + for (int i = 0; i < output_instructions.size(); ++i) { + if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { + continue; + } + + ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({}); + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(unnested_hlo, idx)); + thunks.push_back(std::move(initializer_thunk)); + } + + // Build a kernel thunk to compute all the outputs. + const HloInstruction* first_reduce = nullptr; + for (int i = 0; i < output_instructions.size(); ++i) { + if (IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { + first_reduce = output_instructions[i]; + break; + } + } + CHECK(first_reduce); + if (output_instructions.size() > 1) { + if (!AreFusedReductionOutputsConsistent(output_instructions, + first_reduce)) { + return InternalError("Inconsistent reduction fusion outputs"); + } + } + const Shape& input_shape = first_reduce->operand(0)->shape(); + // The layout of a reduction input is either set by LayoutAssignment for + // unnested kReduce or by InstructionFusion for fused kReduce. + CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " + "doesn't set the input layout of " + << first_reduce->ToString(); + + // Group output instructions. Each group will be executed in parallel. 
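Further down in this function, each reduction group is wrapped in an `if (block_id_y == i)` guard so distinct groups run on distinct blockIdx.y values. Below is a host-side C++ simulation of that dispatch; the real code emits the guard as IR through KernelSupportLibrary::If and reads blockIdx.y via a target intrinsic.

#include <functional>
#include <iostream>
#include <vector>

int main() {
  // One lambda per reduction group, standing in for the emitted group bodies.
  std::vector<std::function<void()>> groups = {
      [] { std::cout << "emit group 0\n"; },
      [] { std::cout << "emit group 1\n"; },
  };

  int blocks_x = 2;                                 // blocks from the mapping scheme
  int blocks_y = static_cast<int>(groups.size());   // one y-slice per group

  // Simulate the launch grid: every (x, y) block runs the same kernel body,
  // and the guard selects which group's code is active for this block.
  for (int block_id_y = 0; block_id_y < blocks_y; ++block_id_y) {
    for (int block_id_x = 0; block_id_x < blocks_x; ++block_id_x) {
      for (int i = 0; i < static_cast<int>(groups.size()); ++i) {
        if (block_id_y == i) {  // the "reduce-group-i" guard
          groups[i]();
        }
      }
    }
  }
}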
+ std::vector> instr_groups = + DivideOutputInstructionsIntoGroups(unnested_hlo, output_instructions); + VLOG(2) << StrCat("Generate in ", instr_groups.size(), " groups for ", + unnested_hlo->ToString()); + std::unique_ptr kernel_thunk = + BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false); + KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); + for (size_t i = 0; i < instr_groups.size(); ++i) { + // Create a new ReductionCodegenInfo instance as it contains states for + // code generation per reduction group. For now, let's always use the very + // first reduce as representative to construct ReductionCodegenInfo, since + // all the reductions are required to have the same shape and layout as + // verified by `AreFusedReductionOutputsConsistent()`. We can loosen the + // constraint later when the needs arise. + ReductionCodegenInfo reduction_info = + ComputeReductionCodegenInfo(unnested_hlo, first_reduce); + auto emit_reduction_func = [&] { + EmitIRForReduction(unnested_hlo, instr_groups[i], &reduction_info, + input_shape); + }; + // Use raw block_id_y to select the i-th parallel reduction to run. Using + // block_id_y instead of block_id_x simplifies the index calculation + // for reduction code generation as the block_id_y is orthogonal to + // the indices used within the reductions. + llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( + gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_); + llvm_ir::AddRangeMetadata(0, instr_groups.size(), + llvm::cast(raw_block_id_y)); + llvm::Value* guarding_cond = + b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i)); + ksl.If(StrCat("reduce-group-", i), guarding_cond, emit_reduction_func); + } + ReductionCodegenInfo reduction_info = + ComputeReductionCodegenInfo(unnested_hlo, first_reduce); + const KernelMappingScheme& mapping_scheme = + reduction_info.GetKernelMappingScheme(); + // block_y_count is set to instr_groups.size(), so that each reduction group + // can be run in parallel by a different BlockIdy. + LaunchDimensions launch_dimensions( + {/*x=*/mapping_scheme.GetNumberOfBlocks(), + /*y=*/static_cast(instr_groups.size()), + /*z=*/1}, + {/*x=*/mapping_scheme.GetThreadsPerBlock(), /*y=*/1, /*z=*/1}); + VLOG(3) << "Launch dimensions of " << unnested_hlo->name() + << ": number of blocks: " << mapping_scheme.GetNumberOfBlocks() + << " - threads per block: " << mapping_scheme.GetThreadsPerBlock(); UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), ir_emitter_context_->llvm_module()); thunks.push_back(std::move(kernel_thunk)); - auto sequential_thunk = absl::make_unique( - GetThunkInfo(unnested_hlo), std::move(thunks)); + std::unique_ptr sequential_thunk = + absl::make_unique(GetThunkInfo(unnested_hlo), + std::move(thunks)); AddThunkToThunkSequence(std::move(sequential_thunk)); return Status::OK(); } -Status IrEmitterUnnested::EmitConstantGlobals() { - for (const BufferAllocation& allocation : - ir_emitter_context_->buffer_assignment().Allocations()) { - if (!allocation.is_constant()) { - continue; - } - - const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation); - const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal); - llvm::ArrayType* global_type = - llvm::ArrayType::get(b_.getInt8Ty(), allocation.size()); - llvm::Constant* initializer = - should_emit_initializer - ? 
llvm_ir::ConvertLiteralToIrConstant(literal, module_) - : llvm::ConstantAggregateZero::get(global_type); - if (should_emit_initializer) { - VLOG(3) << "Emitted initializer for constant with shape " - << ShapeUtil::HumanString(literal.shape()); - } - - // These globals will be looked up by name by GpuExecutable so we need to - // give them an external linkage. Not all of their uses are visible in - // the LLVM IR (e.g. TupleThunk) so we can't give then a linkage that - // merely preserves their names (like available_externally), we also need - // to ensure that they stick around even if they're "unused". - // - // We may have to be more more clever here in the future if we notice that - // we're keeping around too many globals because of their linkage. - unsigned global_address_space = llvm_ir::GetGlobalMemoryAddressSpace( - *ir_emitter_context_->llvm_module()); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - global_type, /*isConstant=*/should_emit_initializer, - llvm::GlobalValue::ExternalLinkage, - /*Initializer=*/initializer, - llvm_ir::ConstantBufferAllocationToGlobalName(allocation), - /*TLMode=*/llvm::GlobalValue::NotThreadLocal, - /*AddressSpace=*/global_address_space, - /*isExternallyInitialized=*/false); - global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes)); - ir_emitter_context_->llvm_module()->getGlobalList().push_back( - global_for_const); - } - - return Status::OK(); -} - // Emits code for slices based on the below structure. An if statement with // a guarding condition is generated for each ROOT slice. // diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 019fcdf21db..c36f0b7840d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ #include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" @@ -28,6 +29,40 @@ limitations under the License. namespace xla { namespace gpu { +struct BufferSlice { + // The root buffer to look at. + BufferAllocation::Slice buffer_slice; + + // Describes how to dereference starting at that buffer to get to the buffer + // in question. + ShapeIndex gte_index; +}; + +// Describes how to access a particular subshape for an HLO. For instance if +// `.hlo_index` is {1} and `.gte_index` is {3, 4} then buffer for `.instr` at +// ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo) is +// found at `.buffer_slice`[3][4]. That is, `.slice` is a void***, which we +// dereference twice -- first at index 3, and then at index 4 -- to get the +// address of our buffer. +struct HloBufferSlice : public BufferSlice { + const HloInstruction* instr; + ShapeIndex hlo_index; +}; + +struct MlirBufferSlice : public BufferSlice { + // The buffer is modified by the kernel. + bool written; + + Shape shape; +}; + +struct MlirEmitterInput { + mlir::Operation* op; + absl::string_view name; + Thunk::ThunkInfo thunk_info; + MlirBufferSlice extra_slice; +}; + // Emits LLVM IR for an "unnested computation". 
// // An unnested computation is an HloComputation which you run by executing one @@ -89,12 +124,14 @@ class IrEmitterUnnested : public IrEmitter, const string& loop_name, llvm::Value* tile_height, llvm::Value* tile_width, KernelSupportLibrary* ksl)>; - IrEmitterUnnested(const HloModuleConfig& hlo_module_config, - const HloComputation* hlo_computation, - IrEmitterContext* ir_emitter_context); IrEmitterUnnested(const IrEmitterUnnested&) = delete; IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete; + static StatusOr> Create( + const HloModuleConfig& hlo_module_config, + const HloComputation* hlo_computation, + IrEmitterContext* ir_emitter_context); + // Transfers the ownship of thunk_sequence_ out. std::unique_ptr ConsumeThunkSequence() { return std::make_unique(std::move(thunk_sequence_)); @@ -124,6 +161,7 @@ class IrEmitterUnnested : public IrEmitter, Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; + Status EmitSortFromMlir(MlirEmitterInput input); Status HandleTriangularSolve(HloInstruction* hlo) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleAllReduce(HloInstruction* crs) override; @@ -142,12 +180,13 @@ class IrEmitterUnnested : public IrEmitter, const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, KernelThunk* thunk, int unroll_factor); - // Emits LLVM global variables corresponding to constant instructions. - Status EmitConstantGlobals(); - Status Postprocess(HloInstruction* hlo) override; private: + IrEmitterUnnested(const HloModuleConfig& hlo_module_config, + const HloComputation* hlo_computation, + IrEmitterContext* ir_emitter_context); + // Add a owning Thunk object to the thunk sequence. void AddThunkToThunkSequence(std::unique_ptr thunk) override { thunk_sequence_.emplace_back(std::move(thunk)); @@ -264,8 +303,7 @@ class IrEmitterUnnested : public IrEmitter, // Builds the prototype of the IR kernel for `inst` and adds it to the module. // This kernel takes as arguments pointers to the given buffer allocations. llvm::Function* BuildKernelPrototype( - const HloInstruction& inst, - absl::Span args); + absl::string_view name, absl::Span args); // Helper for writing extra outputs from inside a reduce kernel. Status EmitExtraOutputsForReduce( @@ -331,6 +369,16 @@ class IrEmitterUnnested : public IrEmitter, // } // ``` // + // Moreover, a heuristic is implemented to divide the reduce instructions + // into groups for parallelization (see `DivideOutputInstructionsIntoGroups` + // for details about the heuristic.) Reduce instructions in the same group + // will run sequentially while different groups will run in parallel. + // + // we use raw block_id_y to select the reduce groups for execution without + // complicating the index calculation in the code generation of the reduce + // instructions. In other words, a block_id_y is assigned to a group and so + // different groups can be run in parallel. + // // output_instructions: Output instructions in the computation: instruction // itself if it's not a fusion, fusion root if fusion is not multi-output, and // elements of the fusion multi-output tuple otherwise. @@ -363,11 +411,10 @@ class IrEmitterUnnested : public IrEmitter, // the process. `scatter` may be fused, scatter indices are taken from // `scatter_indices_gen`, updates from`updates_gen`. The output buffer is // expected to have the operand values in it already. 
If unique_indices - // is false, we will use an atomic update. Using false for unique_indices - // is safe only when it is guaranteed that there are no duplicate - // indices. - // When using unique_indices=true, it is the caller's responsibility to - // ensure there is no overlap. + // is false, we will use an atomic update. Using true for unique_indices + // behaves properly only when it is guaranteed that the indices to be + // updated do not overlap. The caller is responsible for ensuring this is + // the case. Status EmitScatter(Thunk* thunk, HloInstruction* scatter, const llvm_ir::ElementGenerator& scatter_indices_gen, const llvm_ir::ElementGenerator& updates_gen); @@ -478,6 +525,12 @@ class IrEmitterUnnested : public IrEmitter, absl::Span reducers, const TilingKernelInfo& tiling_kernel_info); + // Emits code for reductions in the output_instructions. + void EmitIRForReduction(HloInstruction* unnested_hlo, + absl::Span output_instructions, + ReductionCodegenInfo* reduction_info, + const Shape& input_shape); + // For each reducer, emits the shuffle-down loop to accumulate the partial // result to the global result. void EmitFullWarpShuffleDownLoopForAllReduces( @@ -490,6 +543,12 @@ class IrEmitterUnnested : public IrEmitter, HloComputation* reducer, llvm::Type* element_type, llvm::Value* partial_result_address); + std::unique_ptr BuildKernelThunkFromBufferSlices( + absl::string_view name, Thunk::ThunkInfo thunk_info, + absl::Span slices, + std::function + bind_slice_to_ir_value); + // Returns a KernelThunk that invokes the kernel emitted for `inst`. The // caller needs to make sure `inst` outlives the lifetime of the returned // Thunk object. 'implements_whole_instruction' specifies whether this @@ -498,6 +557,11 @@ class IrEmitterUnnested : public IrEmitter, std::unique_ptr BuildKernelThunk( const HloInstruction* inst, bool implements_whole_instruction); + std::unique_ptr BuildKernelThunkForMlir( + absl::string_view name, Thunk::ThunkInfo thunk_info, + absl::Span slices, + std::vector* ir_arrays); + // Returns a thunk that, given a reduce or select-and-scatter op, // initializes its memory to the appropriate initial value. StatusOr> BuildInitializerThunk( @@ -505,17 +569,18 @@ class IrEmitterUnnested : public IrEmitter, // Returns a WhileThunk that invokes thunk sequences for 'condition' and // 'body' sub-computations of while instruction 'hlo'. - std::unique_ptr BuildWhileThunk(const HloInstruction* hlo); + StatusOr> BuildWhileThunk(const HloInstruction* hlo); // Returns a ForThunk which executes 'loop_limit' invocations of a thunk // sequence from the 'body' sub-computation of the while instruction 'hlo'. - std::unique_ptr BuildForThunk(const HloInstruction* hlo, - const int64 loop_limit); + StatusOr> BuildForThunk(const HloInstruction* hlo, + const int64 loop_limit); // Returns a ConditionalThunk which executes the thunk sequence for the // 'branch_computation' corresponding to the predicate/branch_index of the // given conditional instruction. - std::unique_ptr BuildConditionalThunk(const HloInstruction* hlo); + StatusOr> BuildConditionalThunk( + const HloInstruction* hlo); // Emits current thread id with the given type. // @@ -545,6 +610,9 @@ class IrEmitterUnnested : public IrEmitter, absl::optional thread_id_filter = absl::nullopt, absl::optional block_id_filter = absl::nullopt); + StatusOr GetOrCreateSubComputationFromRegion( + mlir::Region* region); + // Returns the last generated thunk. 
Thunk* LastThunk() const { return thunk_sequence_.back().get(); } @@ -555,6 +623,14 @@ class IrEmitterUnnested : public IrEmitter, // The HloComputation that this IrEmitter emits code for. const HloComputation* hlo_computation_; + + mlir::OwningModuleRef mlir_scratch_module_; + + // This is for cache-purpose only. It has no significant semantics. + mlir::LhloDialectEmitter lhlo_scratch_emitter_; + + absl::flat_hash_map> + scratch_nested_computations_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 19fef37db7e..6c138258aa0 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -115,9 +115,8 @@ Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); - return ExecuteKernelOnStream(*kernel, buffer_args, - launch_dimensions.threads_per_block(), - launch_dimensions.block_count(), params.stream); + return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, + params.stream); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc index 3668a521ec7..c23e8112cb0 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc @@ -26,8 +26,11 @@ namespace gpu { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims) { - out << absl::StrFormat("[block: %d, thread: %d]", launch_dims.block_count(), - launch_dims.threads_per_block()); + LaunchDimensions::Dim3D block_counts = launch_dims.block_counts(); + LaunchDimensions::Dim3D thread_counts = launch_dims.thread_counts_per_block(); + out << absl::StrFormat("[block: {%d, %d, %d}, thread: {%d, %d, %d}]", + block_counts.x, block_counts.y, block_counts.z, + thread_counts.x, thread_counts.y, thread_counts.z); return out; } diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h index 1a5a9d618e4..dbe5a037e43 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h @@ -29,24 +29,37 @@ namespace gpu { // number of threads per block. class LaunchDimensions { public: + struct Dim3D { + int64 x, y, z; + }; + // The default constructor creates a launch dimension that indicate // single-threaded execution. 
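The LaunchDimensions change that starts above (and continues below) replaces the scalar block_count/threads_per_block pair with 3D Dim3D counts, and launch_bound becomes the product of all six components. A standalone sketch of the new shape of the API, with simplified names.

#include <cstdint>
#include <iostream>

struct Dim3D {
  int64_t x, y, z;
};

class LaunchDims {
 public:
  LaunchDims(Dim3D blocks, Dim3D threads)
      : block_counts_(blocks), thread_counts_per_block_(threads) {}

  Dim3D block_counts() const { return block_counts_; }
  Dim3D thread_counts_per_block() const { return thread_counts_per_block_; }

  // Total number of threads in the launch, across all six dimensions.
  int64_t launch_bound() const {
    return block_counts_.x * block_counts_.y * block_counts_.z *
           thread_counts_per_block_.x * thread_counts_per_block_.y *
           thread_counts_per_block_.z;
  }

 private:
  Dim3D block_counts_;
  Dim3D thread_counts_per_block_;
};

int main() {
  // E.g. the grouped-reduction launch: y counts the parallel reduce groups.
  LaunchDims dims(/*blocks=*/{/*x=*/64, /*y=*/3, /*z=*/1},
                  /*threads=*/{/*x=*/256, /*y=*/1, /*z=*/1});
  std::cout << "launch_bound = " << dims.launch_bound() << "\n";  // 49152
}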
- LaunchDimensions() : block_count_(1), threads_per_block_(1) {} + LaunchDimensions() + : block_counts_({1, 1, 1}), thread_counts_per_block_({1, 1, 1}) {} - LaunchDimensions(int64 block_count, int64 threads_per_block) - : block_count_(block_count), threads_per_block_(threads_per_block) {} + LaunchDimensions(int64 block_x_count, int64 thread_x_count_per_block) + : block_counts_({block_x_count, 1, 1}), + thread_counts_per_block_({thread_x_count_per_block, 1, 1}) {} - bool IsSinglethreaded() const { - return block_count_ == 1 && threads_per_block_ == 1; + LaunchDimensions(const Dim3D& block_counts, + const Dim3D& thread_counts_per_block) + : block_counts_(block_counts), + thread_counts_per_block_(thread_counts_per_block) {} + + Dim3D block_counts() const { return block_counts_; } + + Dim3D thread_counts_per_block() const { return thread_counts_per_block_; } + + int64 launch_bound() const { + return block_counts_.x * thread_counts_per_block_.x * block_counts_.y * + thread_counts_per_block_.y * block_counts_.z * + thread_counts_per_block_.z; } - int64 block_count() const { return block_count_; } - int64 threads_per_block() const { return threads_per_block_; } - int64 launch_bound() const { return block_count() * threads_per_block(); } - private: - int64 block_count_; - int64 threads_per_block_; + Dim3D block_counts_; + Dim3D thread_counts_per_block_; }; std::ostream& operator<<(std::ostream& out, diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 1228a1b4823..04af67a70b9 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -62,8 +62,10 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/random.h" #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/util/env_var.h" namespace xla { namespace gpu { @@ -86,14 +88,21 @@ static string GetSmName(std::pair compute_capability) { int sm_version = 30; // If the current compute capability isn't known, fallback to the // most recent version before it. - for (int v : {75, 72, 70, 62, 61, 60, 53, 52, 50, 37, 35, 32, 30}) { + int supported_versions[] = {75, 72, 70, 62, 61, 60, 53, + 52, 50, 37, 35, 32, 30}; + for (int v : supported_versions) { if (v <= compute_capability_version) { sm_version = v; break; } } - if (sm_version != compute_capability_version) { + // If the current CC isn't supported by LLVM and it is newer then + // the max supported LLVM version, do not warn about it. The end + // user can't do anything about this. PTX compiled for SM75 will + // run on SM80 too. + if (sm_version != compute_capability_version && + compute_capability_version < supported_versions[0]) { LOG(WARNING) << "Unknown compute capability (" << compute_capability.first << ", " << compute_capability.second << ") ." 
<< "Defaulting to telling LLVM that we're compiling for sm_" @@ -570,6 +579,60 @@ static std::vector GetROCDLPaths(int amdgpu_version, return result; } +struct HsacoCacheEntry { + uint64 hash; + std::string ir; + int gfx; + std::vector hsaco; +}; + +struct HsacoCache { + protected: + std::vector cache; + std::mutex m_mutex; + int request_count = 0; + int hit_count = 0; + + public: + static bool Find(const std::string& ir, uint64_t& hash, int gfx, + std::vector& hsaco); + static void Add(const std::string& ir, uint64_t hash, int gfx, + const std::vector& hsaco); +}; + +static HsacoCache g_hsacoCache; + +bool HsacoCache::Find(const std::string& ir, uint64_t& hash, int gfx, + std::vector& hsaco) { + std::lock_guard lg(g_hsacoCache.m_mutex); + hash = std::hash{}(ir); + bool hit = false; + for (auto& x : g_hsacoCache.cache) { + if (x.hash != hash) continue; + if (x.gfx != gfx) continue; + if (x.ir != ir) continue; + hsaco = x.hsaco; + hit = true; + break; + } + g_hsacoCache.request_count++; + if (hit) g_hsacoCache.hit_count++; + if (!(g_hsacoCache.request_count % 50)) + VLOG(1) << "HSACO cache: " << g_hsacoCache.request_count << " requests, " + << g_hsacoCache.hit_count << " hits"; + return hit; +} + +void HsacoCache::Add(const std::string& ir, uint64_t hash, int gfx, + const std::vector& hsaco) { + std::lock_guard lg(g_hsacoCache.m_mutex); + g_hsacoCache.cache.resize(g_hsacoCache.cache.size() + 1); + g_hsacoCache.cache.back().ir = ir; + g_hsacoCache.cache.back().hash = hash; + g_hsacoCache.cache.back().gfx = gfx; + g_hsacoCache.cache.back().hsaco = hsaco; +} + // Emits the given module to HSA Code Object. target_machine is an initialized // TargetMachine for the AMDGPU target. StatusOr> EmitModuleToHsaco( @@ -584,18 +647,29 @@ StatusOr> EmitModuleToHsaco( std::string tempdir_name = tempdir_vector.front(); VLOG(1) << "Compile-time artifacts located at: " << tempdir_name; + bool keep_tempfiles = false; + TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_ROCM_KEEP_XLA_TEMPFILES", + /*default_val=*/false, + &keep_tempfiles)); // Prepare filenames for all stages of compilation: // IR, binary ISA, and HSACO. 
- std::string ir_filename = absl::StrCat(module->getModuleIdentifier(), ".ll"); + std::string random_number = std::to_string(tensorflow::random::New64()); + std::string ir_filename = + absl::StrCat(module->getModuleIdentifier(), random_number + ".ll"); std::string ir_path = tensorflow::io::JoinPath(tempdir_name, ir_filename); + std::string ir_opt_filename = + absl::StrCat(module->getModuleIdentifier(), random_number + "_opt.ll"); + std::string ir_opt_path = + tensorflow::io::JoinPath(tempdir_name, ir_opt_filename); + std::string isabin_filename = - absl::StrCat(module->getModuleIdentifier(), ".o"); + absl::StrCat(module->getModuleIdentifier(), random_number + ".o"); std::string isabin_path = tensorflow::io::JoinPath(tempdir_name, isabin_filename); std::string hsaco_filename = - absl::StrCat(module->getModuleIdentifier(), ".hsaco"); + absl::StrCat(module->getModuleIdentifier(), random_number + ".hsaco"); std::string hsaco_path = tensorflow::io::JoinPath(tempdir_name, hsaco_filename); @@ -613,7 +687,7 @@ StatusOr> EmitModuleToHsaco( std::string module_id = module->getModuleIdentifier(); IrDumpingPassManager codegen_passes( ReplaceFilenameExtension(tensorflow::io::Basename(module_id), - "-amdgpu.dummy"), + random_number + "-amdgpu.dummy"), "", false); codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); @@ -627,6 +701,12 @@ StatusOr> EmitModuleToHsaco( codegen_passes.run(*module); isabin_fs->flush(); + if (keep_tempfiles) { + std::unique_ptr ir_fs( + new llvm::raw_fd_ostream(ir_opt_path, ec, llvm::sys::fs::F_None)); + module->print(*ir_fs, nullptr); + ir_fs->flush(); + } // Locate lld. // TODO(whchung@gmail.com): change to tensorflow::ROCmRoot() after // ROCm-Device-Libs PR. @@ -652,9 +732,9 @@ StatusOr> EmitModuleToHsaco( int lld_result = llvm::sys::ExecuteAndWait(*lld_program, llvm_ir::AsArrayRef(lld_args), llvm::None, {}, 0, 0, &error_message); - if (lld_result) { - return xla::InternalError("ld.lld execute fail: %s", error_message); + return xla::InternalError("ld.lld execute fail: %s, error code %d", + error_message, lld_result); } // Read HSACO. @@ -664,6 +744,12 @@ StatusOr> EmitModuleToHsaco( std::vector hsaco(hsaco_file_size); hsaco_file.seekg(0, std::ios::beg); hsaco_file.read(reinterpret_cast(&hsaco[0]), hsaco_file_size); + hsaco_file.close(); + if (!keep_tempfiles) { + remove(ir_path.c_str()); + remove(isabin_path.c_str()); + remove(hsaco_path.c_str()); + } return hsaco; } @@ -728,6 +814,20 @@ StatusOr> CompileToHsaco( std::vector hsaco; std::unique_ptr target_machine; + std::string str; + llvm::raw_string_ostream stream(str); + stream << *module; + // Delete the first two lines, since they usually vary even when the rest of + // the code is the same (but verify that they are what we expect). 
+ if (str.size() >= 13 && str.substr(0, 13) == "; ModuleID = ") { + auto pos = str.find("\n"); + if (pos != std::string::npos) str = str.substr(pos + 1); + } + if (str.size() >= 18 && str.substr(0, 18) == "source_filename = ") { + auto pos = str.find("\n"); + if (pos != std::string::npos) str = str.substr(pos + 1); + } + str += hlo_module_config.compilation_cache_key(); { tensorflow::profiler::TraceMe activity( [&] { return absl::StrCat("Compiling IR", module->getName().str()); }, @@ -739,6 +839,21 @@ StatusOr> CompileToHsaco( return xla::InternalError( "Incompatible AMD GCN ISA version was specified."); } + uint64_t hash; + if (HsacoCache::Find(str, hash, *amdgpu_version, hsaco)) { + VLOG(1) << "HSACO cache hit"; + return hsaco; + } + VLOG(1) << "HSACO cache miss"; + bool dump_lls = false; + if (dump_lls) { + static int hsaco_count = 0; + std::string name = "/tmp/" + std::to_string(hsaco_count) + ".ll"; + hsaco_count++; + std::ofstream ofs(name); + ofs << str; + ofs.close(); + } llvm::Triple default_target_triple("amdgcn--amdhsa-amdgiz"); // Construct LLVM TargetMachine for AMDGPU. @@ -754,6 +869,7 @@ StatusOr> CompileToHsaco( // Lower optimized LLVM module to HSA code object. TF_ASSIGN_OR_RETURN(hsaco, EmitModuleToHsaco(module, target_machine.get())); + HsacoCache::Add(str, hash, *amdgpu_version, hsaco); } return hsaco; } diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index f9937ba77de..6b7b31e8288 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -75,7 +75,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, std::vector array_indices; llvm::Value* block_id = EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_); - llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(), + llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_counts().x, static_cast(block_id)); block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id"); @@ -85,16 +85,17 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, // %ntid.x is currently specified as 1024. llvm::Value* thread_id = EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_); - llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(), + llvm_ir::AddRangeMetadata(0, launch_dimensions_.thread_counts_per_block().x, static_cast(thread_id)); thread_id = b_->CreateZExtOrTrunc(thread_id, index_type, "thread_id"); llvm::Value* linear_index_base = b_->CreateAdd( - b_->CreateMul(block_id, - llvm::ConstantInt::get( - index_type, launch_dimensions_.threads_per_block()), - "", - /*HasNUW=*/true, /*HasNSW=*/true), + b_->CreateMul( + block_id, + llvm::ConstantInt::get( + index_type, launch_dimensions_.thread_counts_per_block().x), + "", + /*HasNUW=*/true, /*HasNSW=*/true), thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); // Add an @llvm.assume(linear_index < threads_per_block * num_blocks). 
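For context on the index arithmetic in the hunk above: only the x components feed the per-thread linear index, while launch_bound() multiplies all six counts, so for a 1D launch the two agree. A minimal sketch using only the constructors and accessors introduced earlier in this diff (values are illustrative):

  // A 1D launch expressed through the new API.
  LaunchDimensions dims(/*block_x_count=*/16, /*thread_x_count_per_block=*/256);
  // y and z default to 1, so launch_bound() is 16 * 256 = 4096, which is the
  // same bound the emitted @llvm.assume places on linear_index
  // (block_counts().x * thread_counts_per_block().x).
  int64 bound = dims.launch_bound();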
@@ -109,9 +110,9 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, llvm::Intrinsic::assume, {b_->CreateICmpULT( linear_index_base, - llvm::ConstantInt::get(index_type, - launch_dimensions_.threads_per_block() * - launch_dimensions_.block_count()), + llvm::ConstantInt::get( + index_type, launch_dimensions_.thread_counts_per_block().x * + launch_dimensions_.block_counts().x), "linear_index_in_range")}, {}, b_); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index d7468a31377..8ea7c57c978 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -209,16 +209,18 @@ StatusOr> CreateKernel( Status ExecuteKernelOnStream(const se::KernelBase& kernel, absl::Span args, - int64 threads_per_block, int64 block_count, - se::Stream* stream) { + const LaunchDimensions& dims, se::Stream* stream) { static constexpr int kKernelArgsLimit = 1024; auto kernel_args = absl::make_unique>(); for (const se::DeviceMemoryBase& buf : args) { kernel_args->add_device_memory_argument(buf); } - return stream->parent()->Launch(stream, se::ThreadDim(threads_per_block), - se::BlockDim(block_count), kernel, - *kernel_args); + LaunchDimensions::Dim3D thread_counts = dims.thread_counts_per_block(); + LaunchDimensions::Dim3D block_counts = dims.block_counts(); + return stream->parent()->Launch( + stream, se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z), + se::BlockDim(block_counts.x, block_counts.y, block_counts.z), kernel, + *kernel_args); } se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config) { diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 0a5e0e93a51..6696d1957b3 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -71,8 +72,7 @@ StatusOr> CreateKernel( // Runs loaded kernel on the stream with the provided arguments. Status ExecuteKernelOnStream(const se::KernelBase& kernel, absl::Span args, - int64 threads_per_block, int64 block_count, - se::Stream* stream); + const LaunchDimensions& dims, se::Stream* stream); // Create GpuAsmOpts out of HloModuleConfig. 
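Call sites of ExecuteKernelOnStream now hand over the whole LaunchDimensions, and all three components reach se::BlockDim/se::ThreadDim. A hedged sketch of a caller, assuming kernel, args and stream already exist and that Dim3D is a plain x/y/z aggregate, as the brace initializers earlier in this diff suggest:

  LaunchDimensions dims(LaunchDimensions::Dim3D{32, 4, 1},    // block counts
                        LaunchDimensions::Dim3D{128, 2, 1});  // threads per block
  TF_RETURN_IF_ERROR(ExecuteKernelOnStream(*kernel, args, dims, stream));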
se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index a2bddd2d0d7..f6e3e965166 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -219,6 +219,28 @@ tf_cc_test( ], ) +tf_cc_test( + name = "parallel_reduction_test", + srcs = [ + "parallel_reduction_test.cc", + ], + tags = tf_cuda_tests_tags() + ["no_rocm"], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "gpu_copy_test", srcs = ["gpu_copy_test.cc"], @@ -375,6 +397,8 @@ tf_cc_test( ":gpu_codegen_test", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:gpu_fusible", + "//tensorflow/compiler/xla/service/gpu:instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -458,6 +482,35 @@ xla_test( ], ) +tf_cc_test( + name = "sorting_test", + srcs = [ + "sorting_test.cc", + ], + tags = tf_cuda_tests_tags() + [ + "no_rocm", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_proto_cc", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/memory", + ], +) + tf_cc_binary( name = "hlo_to_llvm_ir", srcs = ["hlo_to_llvm_ir.cc"], @@ -499,8 +552,15 @@ filegroup( # Binary with only the thunks dialect registered, for testing purposes. tf_cc_binary( name = "xla-thunks-opt", + srcs = ["xla_thunks_opt.cc"], deps = [ - "//tensorflow/compiler/mlir:tf_mlir_opt_main", - "//tensorflow/compiler/xla/service/gpu:xla_thunks_dialect_registration", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/xla/service/gpu:xla_thunks_ops", + "//tensorflow/core:lib", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:Shape", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc index 674b436a8e3..811705d2b17 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" +#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -54,6 +56,37 @@ TEST_F(GpuFusionTest, FusedReshape) { )"); } +// Check that we limit the number of operands to fusions we create. +TEST_F(GpuFusionTest, FusedBiggerThanThresholdButDoNotChangeTheFusion) { + constexpr int64 kNumParams = kMaxOperandsAndOutputsPerFusion + 1; + + // Compute + // p0 + p1 + p2 + ... + pn, + // using so many parameters that they do not fit into one fusion. + auto module = CreateNewVerifiedModule(); + HloComputation::Builder b(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {10, 100}); + Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 2}); + Shape concat_shape = ShapeUtil::MakeShape(F32, {10, 2 * kNumParams}); + HloInstruction* input = + b.AddInstruction(HloInstruction::CreateParameter(0, input_shape, "p")); + + std::vector slice_params; + for (int64 i = 0; i < kNumParams; ++i) { + slice_params.push_back(b.AddInstruction(HloInstruction::CreateSlice( + slice_shape, input, {0, 0}, {10, 2}, {1, 1}))); + } + b.AddInstruction( + HloInstruction::CreateConcatenate(concat_shape, slice_params, 1)); + module->AddEntryComputation(b.Build()); + EXPECT_TRUE(GpuInstructionFusion(false).Run(module.get()).ValueOrDie()); + EXPECT_TRUE(module->entry_computation()->root_instruction()->opcode() == + HloOpcode::kFusion); + for (HloInstruction* instr : module->entry_computation()->instructions()) { + EXPECT_TRUE(instr->opcode() != HloOpcode::kSlice); + } +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index 1e39a4deaa7..8ec00d73711 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -51,7 +51,7 @@ TEST_F(GpuNoAliasTest, Concat) { hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyIr(std::move(hlo_module), - R"(CHECK-LABEL: define void @fusion + R"(CHECK-LABEL: define{{.*}}void @fusion CHECK-SAME: i8* noalias align {{[0-9]*}} dereferenceable({{[0-9]*}}) %[[OUTPUT_ALLOC:[a-z0-9]*]] CHECK: %fusion.raw = {{.*}} %[[OUTPUT_ALLOC]])", /*match_optimized_ir=*/false); diff --git a/tensorflow/compiler/xla/service/gpu/tests/parallel_reduction_test.cc b/tensorflow/compiler/xla/service/gpu/tests/parallel_reduction_test.cc new file mode 100644 index 00000000000..06e547dfe34 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/parallel_reduction_test.cc @@ -0,0 +1,190 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" + +namespace xla { +namespace gpu { + +namespace { + +class ParallelReductionTest : public GpuCodegenTest { + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + // The test contains a MOF fusion and the XLA optimizer passes + // don't like this. + debug_options.set_xla_disable_all_hlo_passes(true); + return debug_options; + } +}; + +TEST_F(ParallelReductionTest, TwoParallelReductions) { + const char* hlo_text = R"( +HloModule TwoParallelReductions + +%add_f32 { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +%fused_computation { + %param0 = f32[1024] parameter(0) + %param1 = f32[1024] parameter(1) + %constant0 = f32[] constant(0) + %reduce1 = f32[] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32 + %reduce2 = f32[] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 + ROOT %tuple = (f32[], f32[]) tuple(%reduce1, %reduce2) +} + +ENTRY %cluster { + %param0 = f32[1024] parameter(0) + %param1 = f32[1024] parameter(1) + ROOT %fusion = (f32[], f32[]) + fusion(%param0, %param1), kind=kInput, calls=%fused_computation +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_text)); + CompileAndVerifyIr(std::move(hlo_module), + R"( +CHECK: reduce-group-0 +CHECK: reduce-group-1 +CHECK-NOT: reduce-group-2 +)", + /*match_optimized_ir=*/false); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(ParallelReductionTest, ManyParallelReductions) { + std::unique_ptr module = CreateNewVerifiedModule(); + // Simply use a number not too large to avoid long compilation time + // and not too small for meaningful test. 
+ const size_t num_reduces = 32; + + HloComputation* reduce_computation; + { + auto embedded_builder = HloComputation::Builder("add"); + HloInstruction* lhs = + embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + HloInstruction* rhs = + embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + reduce_computation = + module->AddEmbeddedComputation(embedded_builder.Build()); + } + + Shape input_shape = ShapeUtil::MakeShape(F32, {1024}); + Shape output_shape = ShapeUtil::MakeShape(F32, {}); + HloComputation* fusion_computation; + { + auto fusion_builder = HloComputation::Builder("fusion_computation"); + std::vector outputs; + HloInstruction* constant = fusion_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + for (size_t i = 0; i < num_reduces; ++i) { + HloInstruction* param = fusion_builder.AddInstruction( + HloInstruction::CreateParameter(i, input_shape, "param")); + HloInstruction* output = + fusion_builder.AddInstruction(HloInstruction::CreateReduce( + output_shape, param, constant, {0}, reduce_computation)); + outputs.push_back(output); + } + fusion_builder.AddInstruction(HloInstruction::CreateTuple(outputs)); + fusion_computation = module->AddEmbeddedComputation(fusion_builder.Build()); + } + + HloComputation::Builder b(TestName()); + std::vector entry_params; + std::vector output_shapes; + for (size_t i = 0; i < num_reduces; ++i) { + HloInstruction* param = b.AddInstruction( + HloInstruction::CreateParameter(i, input_shape, "param")); + entry_params.push_back(param); + output_shapes.push_back(output_shape); + } + b.AddInstruction(HloInstruction::CreateFusion( + ShapeUtil::MakeTupleShape(output_shapes), + HloInstruction::FusionKind::kInput, entry_params, fusion_computation)); + module->AddEntryComputation(b.Build()); + + EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(ParallelReductionTest, ThreeReductionGroups) { + const char* hlo_text = R"( +HloModule ThreeReductionGroups + +%add_f32 { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +%fused_computation { + %param0 = f32[1024,128] parameter(0) + %param1 = f32[1024,128] parameter(1) + %param2 = f32[1024,128] parameter(2) + %constant0 = f32[] constant(0) + // %mul0, %reduce0, and %reduce1 should go into a group. + %broadcast0 = f32[1024,128] broadcast(%constant0), dimensions={} + %mul0 = f32[1024,128] multiply(param0, broadcast0) + %reduce0 = f32[128] reduce(%mul0, %constant0), dimensions={0}, to_apply=%add_f32 + %reduce1 = f32[128] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32 + // %reduce2 and %reduce3 should go into another group. + %reduce2 = f32[128] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 + %reduce3 = f32[128] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 + // %reduce4 and %mul2 should go into the other group, although broadcast0 is + // reused. 
+ %mul1 = f32[1024,128] multiply(param2, broadcast0) + %reduce4 = f32[128] reduce(%mul1, %constant0), dimensions={0}, to_apply=%add_f32 + %mul2 = f32[1024,128] multiply(param2, param2) + ROOT %tuple = + (f32[1024, 128], f32[128], f32[128], f32[128], f32[128], f32[128], f32[1024, 128]) + tuple(%mul2, %reduce0, %reduce4, %reduce3, %reduce2, %reduce1, %mul0) +} + +ENTRY %cluster { + %param0 = f32[1024,128] parameter(0) + %param1 = f32[1024,128] parameter(1) + %param2 = f32[1024,128] parameter(2) + ROOT %fusion = + (f32[1024, 128], f32[128], f32[128], f32[128], f32[128], f32[128], f32[1024, 128]) + fusion(%param0, %param1, %param2), kind=kInput, calls=%fused_computation +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_text)); + CompileAndVerifyIr(std::move(hlo_module), + R"( +CHECK: reduce-group-0 +CHECK: reduce-group-1 +CHECK: reduce-group-2 +CHECK-NOT: reduce-group-3 +)", + /*match_optimized_ir=*/false); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc index 215c2e627ae..5f97452ff71 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc @@ -336,8 +336,17 @@ ENTRY %cluster { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, ParseAndReturnVerifiedModule(hlo_text)); - CompileAndOptionallyVerifyPtx(std::move(optimized_module), - R"( + const se::DeviceDescription& device_description = + backend().default_stream_executor()->GetDeviceDescription(); + int cc_major = 0, cc_minor = 0; + device_description.cuda_compute_capability(&cc_major, &cc_minor); + + string expected; + if (cc_major < 6) { + // We do not vectorize for GPU before Pascal. 
+ expected = "CHECK-NOT: ld.global.nc.v2.f32"; + } else { + expected = R"( CHECK: ld.global.nc.v2.f32 CHECK: st.global.v2.f32 CHECK: st.global.v2.f32 @@ -350,7 +359,9 @@ CHECK: st.global.v2.f32 CHECK: ld.global.nc.v2.f32 CHECK: st.global.v2.f32 CHECK: st.global.v2.f32 -)"); +)"; + } + CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } diff --git a/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo b/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo index c9e7daeb3bc..f625abe6612 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo @@ -1,6 +1,6 @@ // RUN: hlo_to_llvm_ir %s | FileCheck %s -// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) { +// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) { // CHECK: entry: // CHECK: %[[VAL_32:.*]] = alloca i32, align 4 // CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0 @@ -43,8 +43,8 @@ // CHECK: store atomic i32 %[[VAL_36]], i32* %[[VAL_31]] unordered, align 4 // CHECK: br label %[[VAL_23]] // CHECK: !nvvm.annotations = !{!0, !1} -// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1} -// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6} +// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6} // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{i32 0, i32 6} // CHECK: !4 = !{} @@ -72,7 +72,7 @@ ENTRY main { // ----- -// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 64 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 %alloc3) { +// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 16 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 %alloc2) { // CHECK: entry: // CHECK: %[[VAL_60:.*]] = alloca i32, align 4 // CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0 @@ -104,8 +104,8 @@ ENTRY main { // CHECK: store atomic i32 %[[VAL_62]], i32* %[[VAL_39]] unordered, align 4 // CHECK: br label %[[VAL_57]] // CHECK: !nvvm.annotations = !{!0, !1} -// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1} -// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1} +// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1} // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{} @@ -131,7 +131,7 @@ ENTRY main { // ----- -// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) { +// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 
16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) { // CHECK: %[[VAL_63:.*]] = alloca i32, align 4 // CHECK: %[[VAL_64:.*]] = alloca i32, align 4 // CHECK: %[[VAL_98:.*]] = alloca i32, align 4 @@ -188,8 +188,8 @@ ENTRY main { // CHECK: %[[VAL_109:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 1 // CHECK: br i1 %[[VAL_109]], label %[[VAL_96]], label %[[VAL_104]] // CHECK: !nvvm.annotations = !{!0, !1} -// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1} -// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6} +// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6} // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{i32 0, i32 6} // CHECK: !4 = !{} @@ -216,7 +216,7 @@ ENTRY main { // ----- -// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 64 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(16) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 dereferenceable(4) %alloc3) { +// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 16 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2) { // CHECK: entry: // CHECK: %[[VAL_146:.*]] = alloca i32, align 4 // CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0 @@ -253,8 +253,8 @@ ENTRY main { // CHECK: store atomic i32 %[[VAL_148]], i32* %[[VAL_145]] unordered, align 4 // CHECK: br label %[[VAL_138]] // CHECK: !nvvm.annotations = !{!0, !1} -// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1} -// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1} +// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1} // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{} diff --git a/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo b/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo new file mode 100644 index 00000000000..4d29a8df116 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo @@ -0,0 +1,382 @@ +// RUN: hlo_to_llvm_ir %s | FileCheck %s + +HloModule TestModule + +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + +// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]]) +// CHECK-NEXT: entry: +// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 +// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP4]] to i64 +// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64 +// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]] +// 
CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) +// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2 +// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK: sort.in_bounds-after: +// CHECK-NEXT: ret void +// CHECK: sort.in_bounds-true: +// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP8]], 2 +// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1 +// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]] +// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3 +// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]] +// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK: smaller_comparison_index-after: +// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] +// CHECK: smaller_comparison_index-true: +// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]] +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] +// CHECK-NEXT: call void @region_0_4(float* [[TMP16]], float* [[TMP17]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP18:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP18]], 0 +// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] +// CHECK: is_smaller_than-after: +// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] +// CHECK: is_smaller_than-true: +// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4 +// CHECK-NEXT: [[TMP20:%.*]] = load float, float* [[TMP17]], align 4 +// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] +// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4 +// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]] +// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4 +// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] + +// CHECK: define internal void @region_0_4(float* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]]) +// CHECK-NEXT: entry: +// CHECK-NEXT: [[COMPARE_3_TYPED:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[ARG_0_1_TYPED:%.*]], align 4 +// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[ARG_1_2_TYPED:%.*]], align 4 +// CHECK-NEXT: [[TMP2:%.*]] = fcmp olt float [[TMP0]], [[TMP1]] +// CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i8 +// CHECK-NEXT: store i8 [[TMP3]], i8* [[COMPARE_3_TYPED]], align 1 +// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[COMPARE_3_TYPED]], align 1 +// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG:%.*]], align 1 +// CHECK-NEXT: ret void + +// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]]) { +// CHECK-NEXT: entry: +// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* 
[[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 +// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP4]] to i64 +// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64 +// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]] +// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) +// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2 +// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK: sort.in_bounds-after: +// CHECK-NEXT: ret void +// CHECK: sort.in_bounds-true: +// CHECK-NEXT: [[TMP11:%.*]] = xor i64 [[TMP8]], 3 +// CHECK-NEXT: [[TMP12:%.*]] = icmp slt i64 [[TMP8]], [[TMP11]] +// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], 3 +// CHECK-NEXT: [[TMP14:%.*]] = and i1 [[TMP12]], [[TMP13]] +// CHECK-NEXT: br i1 [[TMP14]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK: smaller_comparison_index-after: +// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] +// CHECK: smaller_comparison_index-true: +// CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] +// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP8]] +// CHECK-NEXT: call void @region_0_4(float* [[TMP15]], float* [[TMP16]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP17:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP17]], 0 +// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] +// CHECK: is_smaller_than-after: +// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] +// CHECK: is_smaller_than-true: +// CHECK-NEXT: [[TMP18:%.*]] = load float, float* [[TMP15]], align 4 +// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4 +// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP8]] +// CHECK-NEXT: store float [[TMP18]], float* [[TMP20]], align 4 +// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] +// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4 +// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] + +// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]]) { +// CHECK-NEXT: entry: +// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 +// 
CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP4]] to i64 +// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64 +// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]] +// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) +// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2 +// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK: sort.in_bounds-after: +// CHECK-NEXT: ret void +// CHECK: sort.in_bounds-true: +// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP8]], 2 +// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1 +// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]] +// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3 +// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]] +// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK: smaller_comparison_index-after: +// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] +// CHECK: smaller_comparison_index-true: +// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]] +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] +// CHECK-NEXT: call void @region_0_4(float* [[TMP16]], float* [[TMP17]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP18:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP18]], 0 +// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] +// CHECK: is_smaller_than-after: +// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] +// CHECK: is_smaller_than-true: +// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4 +// CHECK-NEXT: [[TMP20:%.*]] = load float, float* [[TMP17]], align 4 +// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] +// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4 +// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]] +// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4 +// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] +ENTRY main { + x = f32[2, 3] parameter(0) + ROOT sort = f32[2, 3] sort(x), dimensions={1}, to_apply=compare +} + +// ----- + +HloModule TestModule + +compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + p.1.lhs = f32[] parameter(2) + p.1.rhs = f32[] parameter(3) + ROOT lt = pred[] compare(p.1.lhs, p.1.rhs), direction=LT +} + +// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) +// 
CHECK-NEXT: entry: +// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 +// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 +// CHECK-NEXT: [[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]* +// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64 +// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]] +// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) +// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2 +// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK: sort.in_bounds-after: +// CHECK-NEXT: ret void +// CHECK: sort.in_bounds-true: +// CHECK-NEXT: [[TMP13:%.*]] = mul i64 [[TMP10]], 2 +// CHECK-NEXT: [[TMP14:%.*]] = xor i64 [[TMP13]], 1 +// CHECK-NEXT: [[TMP15:%.*]] = icmp slt i64 [[TMP13]], [[TMP14]] +// CHECK-NEXT: [[TMP16:%.*]] = icmp slt i64 [[TMP14]], 3 +// CHECK-NEXT: [[TMP17:%.*]] = and i1 [[TMP15]], [[TMP16]] +// CHECK-NEXT: br i1 [[TMP17]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK: smaller_comparison_index-after: +// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] +// CHECK: smaller_comparison_index-true: +// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP14]] +// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP14]] +// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: call void @region_0_6(i32* [[TMP18]], i32* [[TMP19]], float* [[TMP20]], float* [[TMP21]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP22:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP22]], 0 +// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] +// CHECK: is_smaller_than-after: +// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] +// CHECK: is_smaller_than-true: +// CHECK-NEXT: [[TMP23:%.*]] = load i32, i32* [[TMP18]], align 4 +// CHECK-NEXT: [[TMP24:%.*]] = load i32, i32* [[TMP19]], align 4 +// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: store i32 [[TMP23]], i32* [[TMP25]], align 4 +// CHECK-NEXT: [[TMP26:%.*]] = 
getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP14]] +// CHECK-NEXT: store i32 [[TMP24]], i32* [[TMP26]], align 4 +// CHECK-NEXT: [[TMP27:%.*]] = load float, float* [[TMP20]], align 4 +// CHECK-NEXT: [[TMP28:%.*]] = load float, float* [[TMP21]], align 4 +// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: store float [[TMP27]], float* [[TMP29]], align 4 +// CHECK-NEXT: [[TMP30:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP14]] +// CHECK-NEXT: store float [[TMP28]], float* [[TMP30]], align 4 +// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] + +// CHECK: define internal void @region_0_6(i32* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], i32* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]]) +// CHECK-NEXT: entry: +// CHECK-NEXT: [[COMPARE_5_TYPED:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[ARG_2_3_TYPED:%.*]], align 4 +// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[ARG_3_4_TYPED:%.*]], align 4 +// CHECK-NEXT: [[TMP2:%.*]] = fcmp olt float [[TMP0]], [[TMP1]] +// CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i8 +// CHECK-NEXT: store i8 [[TMP3]], i8* [[COMPARE_5_TYPED]], align 1 +// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[COMPARE_5_TYPED]], align 1 +// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG:%.*]], align 1 +// CHECK-NEXT: ret void + +// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) +// CHECK-NEXT: entry: +// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 +// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 +// CHECK-NEXT: [[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]* +// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64 +// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]] +// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) +// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2 +// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK: sort.in_bounds-after: +// CHECK-NEXT: ret void +// CHECK: sort.in_bounds-true: +// CHECK-NEXT: [[TMP13:%.*]] = xor i64 [[TMP10]], 3 +// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP10]], [[TMP13]] 
+// CHECK-NEXT: [[TMP15:%.*]] = icmp slt i64 [[TMP13]], 3 +// CHECK-NEXT: [[TMP16:%.*]] = and i1 [[TMP14]], [[TMP15]] +// CHECK-NEXT: br i1 [[TMP16]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK: smaller_comparison_index-after: +// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] +// CHECK: smaller_comparison_index-true: +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP10]] +// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP10]] +// CHECK-NEXT: call void @region_0_6(i32* [[TMP17]], i32* [[TMP18]], float* [[TMP19]], float* [[TMP20]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP21:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP21]], 0 +// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] +// CHECK: is_smaller_than-after: +// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] +// CHECK: is_smaller_than-true: +// CHECK-NEXT: [[TMP22:%.*]] = load i32, i32* [[TMP17]], align 4 +// CHECK-NEXT: [[TMP23:%.*]] = load i32, i32* [[TMP18]], align 4 +// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP10]] +// CHECK-NEXT: store i32 [[TMP22]], i32* [[TMP24]], align 4 +// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: store i32 [[TMP23]], i32* [[TMP25]], align 4 +// CHECK-NEXT: [[TMP26:%.*]] = load float, float* [[TMP19]], align 4 +// CHECK-NEXT: [[TMP27:%.*]] = load float, float* [[TMP20]], align 4 +// CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP10]] +// CHECK-NEXT: store float [[TMP26]], float* [[TMP28]], align 4 +// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: store float [[TMP27]], float* [[TMP29]], align 4 +// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] + +// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) +// CHECK-NEXT: entry: +// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 +// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 +// CHECK-NEXT: [[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]* +// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), 
!range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64 +// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]] +// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) +// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2 +// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK: sort.in_bounds-after: +// CHECK-NEXT: [[TMP13:%.*]] = bitcast [2 x [3 x i32]]* [[TMP1]] to i8* +// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[TMP5]], i64 0, i64 0 +// CHECK-NEXT: store i8* [[TMP13]], i8** [[TMP14]], align 8 +// CHECK-NEXT: [[TMP15:%.*]] = bitcast [2 x [3 x float]]* [[TMP3]] to i8* +// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[TMP5]], i64 0, i64 1 +// CHECK-NEXT: store i8* [[TMP15]], i8** [[TMP16]], align 8 +// CHECK-NEXT: ret void +// CHECK: sort.in_bounds-true: +// CHECK-NEXT: [[TMP17:%.*]] = mul i64 [[TMP10]], 2 +// CHECK-NEXT: [[TMP18:%.*]] = xor i64 [[TMP17]], 1 +// CHECK-NEXT: [[TMP19:%.*]] = icmp slt i64 [[TMP17]], [[TMP18]] +// CHECK-NEXT: [[TMP20:%.*]] = icmp slt i64 [[TMP18]], 3 +// CHECK-NEXT: [[TMP21:%.*]] = and i1 [[TMP19]], [[TMP20]] +// CHECK-NEXT: br i1 [[TMP21]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK: smaller_comparison_index-after: +// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] +// CHECK: smaller_comparison_index-true: +// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP18]] +// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP17]] +// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP18]] +// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP17]] +// CHECK-NEXT: call void @region_0_6(i32* [[TMP22]], i32* [[TMP23]], float* [[TMP24]], float* [[TMP25]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP26:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP26]], 0 +// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] +// CHECK: is_smaller_than-after: +// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] +// CHECK: is_smaller_than-true: +// CHECK-NEXT: [[TMP27:%.*]] = load i32, i32* [[TMP22]], align 4 +// CHECK-NEXT: [[TMP28:%.*]] = load i32, i32* [[TMP23]], align 4 +// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP17]] +// CHECK-NEXT: store i32 [[TMP27]], i32* [[TMP29]], align 4 +// CHECK-NEXT: [[TMP30:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP18]] +// CHECK-NEXT: store i32 [[TMP28]], i32* [[TMP30]], align 4 +// CHECK-NEXT: [[TMP31:%.*]] = load float, float* [[TMP24]], align 4 +// CHECK-NEXT: [[TMP32:%.*]] = load float, float* [[TMP25]], align 4 
+// CHECK-NEXT: [[TMP33:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP17]] +// CHECK-NEXT: store float [[TMP31]], float* [[TMP33]], align 4 +// CHECK-NEXT: [[TMP34:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP18]] +// CHECK-NEXT: store float [[TMP32]], float* [[TMP34]], align 4 +// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] +ENTRY main { + x = s32[2, 3] parameter(0) + y = f32[2, 3] parameter(1) + ROOT sort = (s32[2, 3], f32[2, 3]) sort(x, y), dimensions={1}, to_apply=compare +} diff --git a/tensorflow/compiler/xla/service/gpu/tests/sorting_test.cc b/tensorflow/compiler/xla/service/gpu/tests/sorting_test.cc new file mode 100644 index 00000000000..197a0c6cfeb --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/sorting_test.cc @@ -0,0 +1,71 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +class SortingTest : public GpuCodegenTest { + protected: + HloModuleConfig ConfigWithoutLayoutAssignment() { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + // Disable layout_assignment to use the preassigned layouts. 
+ debug_options.add_xla_disable_hlo_passes("layout-assignment"); + config.set_debug_options(debug_options); + return config; + } +}; + +TEST_F(SortingTest, Regression1) { + const char* hlo_text = R"( +HloModule TestModule + +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + +ENTRY TestComputation { + x = f32[3, 2]{1, 0} parameter(0) + x.copy = f32[3, 2]{0, 1} copy(x) + ROOT sort = f32[3, 2]{0, 1} sort(x.copy), dimensions={1}, to_apply=compare +} + +)"; + + EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/xla_thunks_opt.cc b/tensorflow/compiler/xla/service/gpu/tests/xla_thunks_opt.cc new file mode 100644 index 00000000000..97c3b3a5bde --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/xla_thunks_opt.cc @@ -0,0 +1,39 @@ +/* Copyright 2020 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Support/MlirOptMain.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h" +#include "tensorflow/core/platform/init_main.h" + +int main(int argc, char **argv) { + tensorflow::InitMlir y(&argc, &argv); + + mlir::registerAllPasses(); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + mlir::mhlo::registerAllMhloDialects(registry); + registry.insert(); + registry.insert(); + return failed( + mlir::MlirOptMain(argc, argv, "XLA-Thunk pass driver\n", registry)); +} diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 10751752571..2e2b668eba7 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_live_range.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" +#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -55,9 +56,10 @@ StatusOr HeapSimulator::MinimumMemoryForModule( // rather than summing each computation, since it gives us a better lower // bound, by minimizing the liveness of sub-computations. 
TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(absl::make_unique(), *module, - schedule, *alias_analysis, size_function)); + HeapSimulator::Result result, + HeapSimulator::Run( + absl::make_unique>(), *module, + schedule, *alias_analysis, size_function)); return result.heap_size; } @@ -69,10 +71,11 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( const absl::flat_hash_map* memory_by_computation) { TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(absl::make_unique(), - computation, sequence, alias_analysis, size_function, - HeapSimulator::Options(), memory_by_computation)); + HeapSimulator::Result result, + HeapSimulator::Run( + absl::make_unique>(), computation, + sequence, alias_analysis, size_function, HeapSimulator::Options(), + memory_by_computation)); return result.heap_size; } @@ -82,16 +85,17 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( const LogicalBuffer::SizeFunction& size_function, const HloSchedule* schedule) { TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(absl::make_unique(), - computation, sequence, alias_analysis, size_function, - schedule, HeapSimulator::Options())); + HeapSimulator::Result result, + HeapSimulator::Run( + absl::make_unique>(), computation, + sequence, alias_analysis, size_function, schedule, + HeapSimulator::Options())); return result.heap_size; } /*static*/ -StatusOr HeapSimulator::Run( - std::unique_ptr algorithm, const HloModule& module, +StatusOr> HeapSimulator::Run( + std::unique_ptr> algorithm, const HloModule& module, const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_fn, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule); @@ -108,8 +112,9 @@ StatusOr HeapSimulator::Run( } /*static*/ -StatusOr HeapSimulator::Run( - std::unique_ptr algorithm, const HloComputation& computation, +StatusOr> HeapSimulator::Run( + std::unique_ptr> algorithm, + const HloComputation& computation, const HloInstructionSequence& instruction_sequence, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_fn, const Options& options, @@ -128,8 +133,9 @@ StatusOr HeapSimulator::Run( } /*static*/ -StatusOr HeapSimulator::Run( - std::unique_ptr algorithm, const HloComputation& computation, +StatusOr> HeapSimulator::Run( + std::unique_ptr> algorithm, + const HloComputation& computation, const HloInstructionSequence& instruction_sequence, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule, @@ -326,12 +332,13 @@ Status HeapSimulator::RunComputation( } HeapSimulator::HeapSimulator( - std::unique_ptr algorithm, + std::unique_ptr> algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, const HloSchedule* schedule, const absl::flat_hash_map* memory_by_computation) - : no_fragmentation_stats_(absl::make_unique()), + : no_fragmentation_stats_( + absl::make_unique>()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), @@ -396,8 +403,8 @@ void HeapSimulator::ShareBuffer(const HloValue* buffer, const HloValue* shared, shared); } -HeapSimulator::Result HeapSimulator::Finish() { - Result result = algorithm_->Finish(); +HeapSimulator::Result HeapSimulator::Finish() { + Result result = algorithm_->Finish(); // Post-process the result to add chunks for shared buffers. 
An empty chunk // map means that either no buffers were allocated, or the heap was only @@ -411,7 +418,7 @@ HeapSimulator::Result HeapSimulator::Finish() { } // Fragmentation is the difference between the actual and ideal sizes. - const Result no_frag_result = no_fragmentation_stats_->Finish(); + const Result no_frag_result = no_fragmentation_stats_->Finish(); result.fragmentation_size = result.heap_size - no_frag_result.heap_size; // Copy the debug trace we collected to the final result. @@ -437,14 +444,17 @@ void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, } } -void NoFragmentationStatsHeap::Alloc(const HloValue* buffer, int64 size) { +template +void NoFragmentationStatsHeap::Alloc(const BufferType* buffer, + int64 size) { current_heap_size_ += size; if (current_heap_size_ > max_heap_size_) { max_heap_size_ = current_heap_size_; } } -void NoFragmentationStatsHeap::AccountForSubcomputationMemory( +template +void NoFragmentationStatsHeap::AccountForSubcomputationMemory( const HloInstruction* instruction, int64 alloc_size_by_instruction, const absl::flat_hash_map& memory_by_computation) { @@ -472,11 +482,15 @@ void NoFragmentationStatsHeap::AccountForSubcomputationMemory( std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes); } -void NoFragmentationStatsHeap::Free(const HloValue* buffer, int64 size) { +template +void NoFragmentationStatsHeap::Free(const BufferType* buffer, + int64 size) { current_heap_size_ -= size; } -HeapSimulator::Result NoFragmentationStatsHeap::Finish() { +template +HeapSimulator::Result +NoFragmentationStatsHeap::Finish() { // The result.chunk_map is empty, since we only collect stats, and don't // actually compute chunk assignments. Result result; @@ -484,7 +498,8 @@ HeapSimulator::Result NoFragmentationStatsHeap::Finish() { return result; } -GlobalDecreasingSizeBestFitHeap::GlobalDecreasingSizeBestFitHeap( +template +GlobalDecreasingSizeBestFitHeap::GlobalDecreasingSizeBestFitHeap( int64 alignment, Type type) : alignment_(alignment) { if (type == kTemporal) { @@ -495,8 +510,10 @@ GlobalDecreasingSizeBestFitHeap::GlobalDecreasingSizeBestFitHeap( } } -GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare -GlobalDecreasingSizeBestFitHeap::GetTemporalBufferIntervalCompare() const { +template +typename GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare +GlobalDecreasingSizeBestFitHeap::GetTemporalBufferIntervalCompare() + const { return [&](const BufferInterval& x, const BufferInterval& y) { int64 x_end = x.end; for (auto colocation : GetTransitiveColocations(x)) { @@ -515,12 +532,14 @@ GlobalDecreasingSizeBestFitHeap::GetTemporalBufferIntervalCompare() const { if (x.size != y.size) { return x.size > y.size; } - return x.buffer->id() < y.buffer->id(); + return *x.buffer < *y.buffer; }; } -/*static*/ GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare -GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare() { +template +/*static*/ typename GlobalDecreasingSizeBestFitHeap< + BufferType>::BufferIntervalCompare +GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare() { return [&](const BufferInterval& x, const BufferInterval& y) { if (x.size != y.size) { return x.size > y.size; @@ -528,12 +547,13 @@ GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare() { if (x.end - x.start != y.end - y.start) { return x.end - x.start > y.end - y.start; } - return x.buffer->id() < y.buffer->id(); + return *x.buffer < *y.buffer; }; } -void GlobalDecreasingSizeBestFitHeap::Alloc(const HloValue* buffer, - int64 
size) { +template +void GlobalDecreasingSizeBestFitHeap::Alloc( + const BufferType* buffer, int64 size) { // Degenerate case: 0-sized buffers are always allocated at offset 0. if (size == 0) { result_.chunk_map.emplace(buffer, Chunk{0, 0}); @@ -546,9 +566,9 @@ void GlobalDecreasingSizeBestFitHeap::Alloc(const HloValue* buffer, ++current_time_; } -void GlobalDecreasingSizeBestFitHeap::ShareWith(const HloValue* buffer, - const HloValue* share_with, - int64 size) { +template +void GlobalDecreasingSizeBestFitHeap::ShareWith( + const BufferType* buffer, const BufferType* share_with, int64 size) { // Degenerate case: 0-sized buffers are always allocated at offset 0. if (size == 0) { result_.chunk_map.emplace(buffer, Chunk{0, 0}); @@ -562,15 +582,16 @@ void GlobalDecreasingSizeBestFitHeap::ShareWith(const HloValue* buffer, ++current_time_; } -absl::flat_hash_set -GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations( +template +absl::flat_hash_set +GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations( const BufferInterval& interval) const { - absl::flat_hash_set result; + absl::flat_hash_set result; std::vector worklist = {&interval}; while (!worklist.empty()) { const BufferInterval* item = worklist.back(); worklist.pop_back(); - for (const HloValue* buffer_colocated : item->colocations) { + for (const BufferType* buffer_colocated : item->colocations) { result.insert(buffer_colocated); worklist.push_back(&buffer_intervals_.at(buffer_colocated)); } @@ -579,7 +600,9 @@ GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations( return result; } -void GlobalDecreasingSizeBestFitHeap::Free(const HloValue* buffer, int64 size) { +template +void GlobalDecreasingSizeBestFitHeap::Free(const BufferType* buffer, + int64 size) { // Degenerate case: 0-sized buffers are always allocated at offset 0. 
if (size == 0) { return; @@ -785,7 +808,9 @@ std::vector BufferIntervalTree::ChunksOverlappingInTime( return result; } -HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() { +template +HeapSimulator::Result +GlobalDecreasingSizeBestFitHeap::Finish() { std::vector sorted_buffer_intervals = GetSortedBufferIntervals(); @@ -803,8 +828,10 @@ HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() { return result_; } -std::vector -GlobalDecreasingSizeBestFitHeap::GetSortedBufferIntervals() const { +template +std::vector< + typename GlobalDecreasingSizeBestFitHeap::BufferInterval> +GlobalDecreasingSizeBestFitHeap::GetSortedBufferIntervals() const { std::vector sorted_buffer_intervals; for (auto& entry : buffer_intervals_) { sorted_buffer_intervals.push_back(entry.second); @@ -814,8 +841,9 @@ GlobalDecreasingSizeBestFitHeap::GetSortedBufferIntervals() const { return sorted_buffer_intervals; } -GlobalDecreasingSizeBestFitHeap::ChunkCandidate -GlobalDecreasingSizeBestFitHeap::FindChunkCandidate( +template +typename GlobalDecreasingSizeBestFitHeap::ChunkCandidate +GlobalDecreasingSizeBestFitHeap::FindChunkCandidate( const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval, int64 preferred_offset) const { VLOG(1) << "Finding chunks for buffer: " @@ -912,9 +940,12 @@ GlobalDecreasingSizeBestFitHeap::FindChunkCandidate( return chunk_candidate; } -void GlobalDecreasingSizeBestFitHeap::CommitChunk( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval, - GlobalDecreasingSizeBestFitHeap::ChunkCandidate chunk_candidate) { +template +void GlobalDecreasingSizeBestFitHeap::CommitChunk( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& + buffer_interval, + GlobalDecreasingSizeBestFitHeap::ChunkCandidate + chunk_candidate) { // Update the maximum heap size according to the one determined by the chunk // candidate. result_.heap_size = chunk_candidate.heap_size; @@ -930,13 +961,16 @@ void GlobalDecreasingSizeBestFitHeap::CommitChunk( AddToChunkMap(buffer_interval.buffer, chunk_candidate.chunk); } -void GlobalDecreasingSizeBestFitHeap::AddToChunkMap(const HloValue* buffer, - Chunk chunk) { +template +void GlobalDecreasingSizeBestFitHeap::AddToChunkMap( + const BufferType* buffer, Chunk chunk) { const auto emplace_result = result_.chunk_map.emplace(buffer, chunk); DCHECK(emplace_result.second); } -HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() { +template +HeapSimulator::Result +ChooseBestHeapAlgorithm::Finish() { DCHECK(!algorithms_.empty()); std::vector results(algorithms_.size()); int64 min_size = INT64_MAX; @@ -953,4 +987,9 @@ HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() { return results[min_size_index]; } +template class GlobalDecreasingSizeBestFitHeap; +template class GlobalDecreasingSizeBestFitHeap< + MemorySpaceAssignmentRepacker::AllocationBlock>; +template class ChooseBestHeapAlgorithm; + } // namespace xla diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index d3b781ded0c..b47ff685139 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -40,7 +40,9 @@ limitations under the License. namespace xla { // Forward declare classes defined below. +template class HeapAlgorithm; +template class NoFragmentationStatsHeap; // HeapSimulator assigns buffer offsets by running a simulation of a regular @@ -66,9 +68,10 @@ class HeapSimulator { }; // Result represents the result of the heap simulation. 
+ template struct Result { // The assignment of buffers to chunks. - absl::flat_hash_map chunk_map; + absl::flat_hash_map chunk_map; // The total size in bytes of the heap, containing all assigned chunks. int64 heap_size = 0; @@ -128,19 +131,19 @@ class HeapSimulator { // to running on a per-computation basis, since we can re-use buffer space for // called sub-computations. // - static StatusOr Run(std::unique_ptr algorithm, - const HloModule& module, - const HloSchedule& schedule, - const HloAliasAnalysis& alias_analysis, - const BufferValue::SizeFunction& size_fn, - const Options& options = Options()); + static StatusOr> Run( + std::unique_ptr> algorithm, + const HloModule& module, const HloSchedule& schedule, + const HloAliasAnalysis& alias_analysis, + const BufferValue::SizeFunction& size_fn, + const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' // must contain a topologically-consistent total ordering of all instructions // in the computation. The result is invalid if instructions are not run in // exactly this sequence. - static StatusOr Run( - std::unique_ptr algorithm, + static StatusOr> Run( + std::unique_ptr> algorithm, const HloComputation& computation, const HloInstructionSequence& instruction_sequence, const HloAliasAnalysis& alias_analysis, @@ -151,8 +154,8 @@ class HeapSimulator { // Same as above, but runs with a schedule that covers all nested // computations. - static StatusOr Run( - std::unique_ptr algorithm, + static StatusOr> Run( + std::unique_ptr> algorithm, const HloComputation& computation, const HloInstructionSequence& instruction_sequence, const HloAliasAnalysis& alias_analysis, @@ -163,7 +166,7 @@ class HeapSimulator { // If 'schedule' is non-null, it is used to find kCall and kWhile // sub-computations, and the heap simulation for those sub-computations will // be run recursively. I.e. the simulation is run over the whole module. - HeapSimulator(std::unique_ptr algorithm, + HeapSimulator(std::unique_ptr> algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, const HloSchedule* schedule = nullptr, const absl::flat_hash_map* @@ -187,7 +190,7 @@ class HeapSimulator { // Two buffers belong to the same shared group. // Neither of the buffers has a shared group assigned. bool InSameSharedGroup(const HloValue* left, const HloValue* right); - Result Finish(); + Result Finish(); void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, const HloValue* buffer, const HloInstruction* instruction, @@ -196,8 +199,9 @@ class HeapSimulator { // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap, // in which case we are calculating the same allocs/frees twice in the // simulation. - const std::unique_ptr no_fragmentation_stats_; - const std::unique_ptr algorithm_; + const std::unique_ptr> + no_fragmentation_stats_; + const std::unique_ptr> algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; // schedule_ is set by buffer assignment, and memory_by_computation_ is @@ -220,15 +224,16 @@ class HeapSimulator { // offsets to buffers. A sequence of Alloc / Free calls will be made, with the // same semantics as a regular memory heap. Finish will be called at the end to // collect the simulation results. +template class HeapAlgorithm { public: using Chunk = HeapSimulator::Chunk; - using Result = HeapSimulator::Result; + using Result = HeapSimulator::Result; virtual ~HeapAlgorithm() = default; // Alloc allocates a buffer of 'size' bytes.
- virtual void Alloc(const HloValue* buffer, int64 size) = 0; + virtual void Alloc(const BufferType* buffer, int64 size) = 0; // Takes memory usage of subcomputations into account when calculating the // memory usage of a computation. Currently, we don't handle buffer aliasing @@ -247,7 +252,7 @@ class HeapAlgorithm { memory_by_computation) {} // Free de-allocates a previously allocated buffer. - virtual void Free(const HloValue* buffer, int64 size) = 0; + virtual void Free(const BufferType* buffer, int64 size) = 0; // Indicates that a buffer has to be collocated with another buffer. In // addition to Alloc and Free, the heap simulator exposes a concept of buffer @@ -255,7 +260,7 @@ class HeapAlgorithm { // the buffer, it associates the buffer with a previously allocated (or // shared) buffer. Each group of mutually-shared buffers points to a single // SharedGroup instance, which is a shared control block. - virtual void ShareWith(const HloValue* buffer, const HloValue* share_with, + virtual void ShareWith(const BufferType* buffer, const BufferType* share_with, int64 size) { Alloc(buffer, size); } @@ -269,19 +274,22 @@ class HeapAlgorithm { // this is the absolute minimum size for a given instruction sequence. The // result.chunk_map returned in Finish is always empty, since we only collect // stats, and don't actually compute chunk assignments. -class NoFragmentationStatsHeap : public HeapAlgorithm { +template +class NoFragmentationStatsHeap : public HeapAlgorithm { public: + using Result = HeapSimulator::Result; + NoFragmentationStatsHeap() = default; ~NoFragmentationStatsHeap() override = default; - void Alloc(const HloValue* buffer, int64 size) override; + void Alloc(const BufferType* buffer, int64 size) override; void AccountForSubcomputationMemory( const HloInstruction* instruction, int64 alloc_size_by_instruction, const absl::flat_hash_map& memory_by_computation) override; - void Free(const HloValue* buffer, int64 size) override; + void Free(const BufferType* buffer, int64 size) override; Result Finish() override; @@ -336,8 +344,12 @@ class BufferIntervalTree { // alloc/free time. It internally tracks the allocated buffers and their live // intervals; when allocating a buffer, it finds the best-fit free chunk during // its live interval. -class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { +template +class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { public: + using Result = HeapSimulator::Result; + using Chunk = HeapSimulator::Chunk; + enum Type { kSpatial = 0, kTemporal, @@ -345,7 +357,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { // BufferInterval stores a buffer's size and time interval. struct BufferInterval { - const HloValue* buffer; + const BufferType* buffer; int64 size; // Alloc time of the buffer. int64 start; @@ -353,7 +365,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { int64 end; // Colocation buffers that need to be collocated with this one. - std::vector colocations; + std::vector colocations; // True if this buffer needs an allocation. False if it is collocated with // other buffer. 
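Because HeapAlgorithm is now templated on the buffer type, a backend can run the simulation over any buffer-like type, not just HloValue. A minimal sketch of a conforming algorithm is below; it only assumes the Alloc/Free/Finish interface and the Result/Chunk types shown in this header, and the TrivialBumpAllocator name is hypothetical, not part of this change.

template <typename BufferType>
class TrivialBumpAllocator : public HeapAlgorithm<BufferType> {
 public:
  using Chunk = HeapSimulator::Chunk;
  using Result = HeapSimulator::Result<BufferType>;

  // Place every buffer at the current end of the heap; Free never reclaims
  // space, so the final heap_size is simply the sum of all allocation sizes.
  void Alloc(const BufferType* buffer, int64 size) override {
    chunks_.emplace(buffer, Chunk{offset_, size});
    offset_ += size;
  }
  void Free(const BufferType* buffer, int64 size) override {}

  Result Finish() override {
    Result result;
    result.chunk_map = std::move(chunks_);
    result.heap_size = offset_;
    return result;
  }

 private:
  absl::flat_hash_map<const BufferType*, Chunk> chunks_;
  int64 offset_ = 0;
};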
@@ -368,10 +380,10 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { Type type = kSpatial); ~GlobalDecreasingSizeBestFitHeap() override {} - void Alloc(const HloValue* buffer, int64 size) override; - void Free(const HloValue* buffer, int64 size) override; + void Alloc(const BufferType* buffer, int64 size) override; + void Free(const BufferType* buffer, int64 size) override; - void ShareWith(const HloValue* buffer, const HloValue* share_with, + void ShareWith(const BufferType* buffer, const BufferType* share_with, int64 size) override; Result Finish() override; @@ -404,7 +416,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { void CommitChunk(const BufferInterval& buffer_interval, ChunkCandidate chunk_candidate); // Adds the buffer and the chunk to the result chunk map. - virtual void AddToChunkMap(const HloValue* buffer, Chunk chunk); + virtual void AddToChunkMap(const BufferType* buffer, Chunk chunk); // Return a BufferIntervalCompare function that sorts by live ranges. A live // range is defined by the range between the start of the first buffer and the @@ -413,7 +425,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { // contiguous. BufferIntervalCompare GetTemporalBufferIntervalCompare() const; - absl::flat_hash_map buffer_intervals_; + absl::flat_hash_map buffer_intervals_; Result result_; BufferIntervalCompare buffer_interval_compare_; BufferIntervalTree interval_tree_; @@ -428,33 +440,37 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { // Returns all transitive colocated buffers of this buffer interval. I.e., If // a buffer A is colocated with B and B is colocated with C, this function // returns all three of them. - absl::flat_hash_set GetTransitiveColocations( + absl::flat_hash_set GetTransitiveColocations( const BufferInterval& interval) const; }; // A heap algorithm that chooses the best results from other algorithms added to // it. 
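One plausible way to drive the ChooseBestHeapAlgorithm declared next is to hand it both orderings of GlobalDecreasingSizeBestFitHeap and keep whichever packing turns out smaller. The sketch below is illustrative only: the MakeChooser helper and the alignment value are hypothetical, and the constructor argument type is assumed from the declaration that follows.

std::unique_ptr<HeapAlgorithm<HloValue>> MakeChooser() {
  auto algorithms = absl::make_unique<
      std::vector<std::unique_ptr<HeapAlgorithm<HloValue>>>>();
  // Candidate 1: sort buffer intervals by live range (temporal ordering).
  algorithms->push_back(
      absl::make_unique<GlobalDecreasingSizeBestFitHeap<HloValue>>(
          /*alignment=*/64,
          GlobalDecreasingSizeBestFitHeap<HloValue>::kTemporal));
  // Candidate 2: sort buffer intervals by size (spatial ordering).
  algorithms->push_back(
      absl::make_unique<GlobalDecreasingSizeBestFitHeap<HloValue>>(
          /*alignment=*/64,
          GlobalDecreasingSizeBestFitHeap<HloValue>::kSpatial));
  // Finish() on the returned algorithm reports whichever candidate ended up
  // with the smallest heap_size.
  return absl::make_unique<ChooseBestHeapAlgorithm<HloValue>>(
      std::move(algorithms));
}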
-class ChooseBestHeapAlgorithm : public HeapAlgorithm { +template +class ChooseBestHeapAlgorithm : public HeapAlgorithm { public: + using Result = HeapSimulator::Result; + ChooseBestHeapAlgorithm( - std::unique_ptr>> algorithms) + std::unique_ptr>>> + algorithms) : algorithms_(std::move(*algorithms)) {} ~ChooseBestHeapAlgorithm() override {} - void Alloc(const HloValue* buffer, int64 size) override { + void Alloc(const BufferType* buffer, int64 size) override { for (auto& algorithm : algorithms_) { algorithm->Alloc(buffer, size); } } - void ShareWith(const HloValue* buffer, const HloValue* share_with, + void ShareWith(const BufferType* buffer, const BufferType* share_with, int64 size) override { for (auto& algorithm : algorithms_) { algorithm->ShareWith(buffer, share_with, size); } } - void Free(const HloValue* buffer, int64 size) override { + void Free(const BufferType* buffer, int64 size) override { for (auto& algorithm : algorithms_) { algorithm->Free(buffer, size); } @@ -463,7 +479,7 @@ class ChooseBestHeapAlgorithm : public HeapAlgorithm { Result Finish() override; private: - std::vector> algorithms_; + std::vector>> algorithms_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index b5b711cab4f..8f7668b4965 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -228,7 +228,7 @@ const char kFinish[] = "Finish"; using CallSequence = std::vector>; // HeapCallRecorder is a dummy heap algorithm that simply records its calls. -class HeapCallRecorder : public HeapAlgorithm { +class HeapCallRecorder : public HeapAlgorithm { public: explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {} ~HeapCallRecorder() override {} @@ -396,7 +396,7 @@ class HeapSimulatorTracker { std::unique_ptr module_; std::unique_ptr alias_analysis_; CallSequence actual_calls_; - HeapSimulator::Result result_; + HeapSimulator::Result result_; }; class HeapSimulatorTest : public HloTestBase { @@ -976,12 +976,12 @@ class HeapAlgorithmTestBase : public ::testing::Test { class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {}; TEST_F(NoFragmentationStatsHeapTest, Empty) { - NoFragmentationStatsHeap heap; + NoFragmentationStatsHeap heap; EXPECT_EQ(0, heap.Finish().heap_size); } TEST_F(NoFragmentationStatsHeapTest, Simple) { - NoFragmentationStatsHeap heap; + NoFragmentationStatsHeap heap; heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Alloc(buffer_c_, 30); @@ -994,7 +994,7 @@ TEST_F(NoFragmentationStatsHeapTest, Simple) { } TEST_F(NoFragmentationStatsHeapTest, Mixed) { - NoFragmentationStatsHeap heap; + NoFragmentationStatsHeap heap; heap.Alloc(buffer_a_, 10); // max: A heap.Alloc(buffer_b_, 20); // max: A+B @@ -1013,7 +1013,7 @@ TEST_F(NoFragmentationStatsHeapTest, Mixed) { class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase { protected: class InheritedGlobalDecreasingSizeBestFitHeap - : public GlobalDecreasingSizeBestFitHeap { + : public GlobalDecreasingSizeBestFitHeap { public: InheritedGlobalDecreasingSizeBestFitHeap() : GlobalDecreasingSizeBestFitHeap(/*alignment=*/1) {} @@ -1048,8 +1048,8 @@ class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase { }; TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) { - GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); - const HeapSimulator::Result result = heap.Finish(); + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + 
const HeapSimulator::Result result = heap.Finish(); EXPECT_EQ(0, result.heap_size); EXPECT_EQ(0, result.chunk_map.size()); } @@ -1068,7 +1068,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) { // | | d | // | +-------+ // -----------------> time - GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 30); heap.Alloc(buffer_c_, 20); @@ -1078,7 +1078,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) { heap.Free(buffer_c_, 20); heap.Free(buffer_d_, 40); - const HeapSimulator::Result result = heap.Finish(); + const HeapSimulator::Result result = heap.Finish(); EXPECT_EQ(100, result.heap_size); EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size); @@ -1107,7 +1107,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) { // | | | // | +-------+ // ---------------------> time - GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20); + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Alloc(buffer_c_, 50); @@ -1117,7 +1117,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) { heap.Free(buffer_c_, 50); heap.Free(buffer_d_, 40); - const HeapSimulator::Result result = heap.Finish(); + const HeapSimulator::Result result = heap.Finish(); EXPECT_EQ(120, result.heap_size); EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); @@ -1148,7 +1148,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) { // | | | // | +-------+ // ---------------------> time - GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Alloc(buffer_c_, 40); @@ -1160,7 +1160,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) { heap.Free(buffer_d_, 30); heap.Free(buffer_e_, 50); - const HeapSimulator::Result result = heap.Finish(); + const HeapSimulator::Result result = heap.Finish(); EXPECT_EQ(140, result.heap_size); EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); @@ -1184,7 +1184,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, Colocated) { // || |+----+| | // |+--a---++-b--++---c---+ // ---------------------> time - GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); heap.Alloc(buffer_a_, 40); heap.Free(buffer_a_, 40); heap.Alloc(buffer_b_, 20); @@ -1192,7 +1192,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, Colocated) { heap.ShareWith(buffer_c_, buffer_a_, 40); heap.Free(buffer_c_, 40); - const HeapSimulator::Result result = heap.Finish(); + const HeapSimulator::Result result = heap.Finish(); EXPECT_EQ(40, result.heap_size); EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size); EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); @@ -1212,7 +1212,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedII) { // || | | | <--- colocate with a // |+--a---+ +---c---+ // ---------------------> time - GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); heap.Alloc(buffer_a_, 40); heap.Free(buffer_a_, 40); heap.Alloc(buffer_b_, 20); @@ -1221,7 +1221,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedII) { heap.Free(buffer_c_, 40); heap.Free(buffer_b_, 20); - const HeapSimulator::Result 
result = heap.Finish(); + const HeapSimulator::Result result = heap.Finish(); EXPECT_EQ(60, result.heap_size); EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size); EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); @@ -1242,7 +1242,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedIII) { // | | | // | +-------b-------+ // ---------------------> time - GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); heap.Alloc(buffer_a_, 10); heap.Free(buffer_a_, 10); heap.Alloc(buffer_b_, 30); @@ -1251,7 +1251,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedIII) { heap.Free(buffer_c_, 10); heap.Free(buffer_b_, 30); - const HeapSimulator::Result result = heap.Finish(); + const HeapSimulator::Result result = heap.Finish(); EXPECT_EQ(40, result.heap_size); EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 960f60fe882..c3a7b3a5c14 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 72 +// Next ID: 74 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -248,6 +248,12 @@ message HloInstructionProto { // RNG algorithm used by kRngBitGenerator. xla.RandomAlgorithm rng_algorithm = 70; + + // The comparison type used for kCompare. + string comparison_type = 72; + + // Specifies if this is a cross-program-prefetch, used by kCopyStart. + bool is_cross_program_prefetch = 73; } // Serialization of HloComputation. @@ -283,6 +289,16 @@ message HloScheduleProto { map sequences = 1; } +enum Kind { + // Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3 + // behavior and missing has_*() APIs. + UNDEFINED_ALIAS = 0; + // The buffers may or may not alias at runtime. + MAY_ALIAS = 1; + // The buffers must alias at runtime. + MUST_ALIAS = 2; +} + message HloInputOutputAliasProto { // The following proto describes a pair of aliased an input // (described by parameter number and a ShapeIndex of the parameter) @@ -304,8 +320,8 @@ message HloInputOutputAliasProto { int64 parameter_number = 2; // ShapeIndex of the parameter instruction. repeated int64 parameter_shape_index = 3; - reserved 4; - reserved "kind"; + // The kind of alias to be setup. + Kind kind = 4; } repeated AliasEntryProto entries = 1; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 384ae272dc1..cf09ddeec27 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -308,6 +308,39 @@ class BufferValueMap { } } + void ComputeInPlaceOperationAliasedBuffers( + const HloValue& value, std::vector* aliased_buffers) { + VLOG(3) << "Compute aliases for in-place operations (e.g. 
" + "kDynamicUpdateSlice and kScatter)"; + for (const HloPosition& position : value.positions()) { + HloInstruction* instruction = position.instruction; + for (const auto& operand_and_output_index : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) { + if (position.index == operand_and_output_index.second) { + const HloUse& operand = operand_and_output_index.first; + const HloValue& operand_value = dataflow_.GetUniqueValueAt( + instruction->operand(operand.operand_number), + operand.operand_index); + VLOG(3) << " operand value " << operand_value.ToShortString() + << " aliases."; + aliased_buffers->push_back(GetBufferForValue(operand_value)); + } + } + } + + for (const HloUse& use : value.uses()) { + for (const auto& operand_and_output_index : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(use.instruction)) { + if (use == operand_and_output_index.first) { + const HloValue& use_value = dataflow_.GetUniqueValueAt( + use.instruction, operand_and_output_index.second); + VLOG(3) << " use value " << use_value.ToShortString() << " aliases."; + aliased_buffers->push_back(GetBufferForValue(use_value)); + } + } + } + } + // Compute and return a vector of buffers that the given value must be // contained in due to HLO aliasing rules. std::vector ComputeAliasedBuffers(const HloValue& value) { @@ -318,6 +351,7 @@ class BufferValueMap { ComputeInputOutputAliasedBuffers(value, &aliased_buffers); ComputeWhileAliasedBuffers(value, &aliased_buffers); ComputeConditionalAliasedBuffers(value, &aliased_buffers); + ComputeInPlaceOperationAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. absl::c_sort(aliased_buffers); aliased_buffers.erase( diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 2666cb0872d..5e94f1d173e 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -1062,6 +1062,118 @@ TEST_F(HloAliasAnalysisTest, MergeBuffersReverse) { analysis.BufferLivesOut(analysis.buffers()[0]); } +TEST_F(HloAliasAnalysisTest, DynamicUpdateSlice) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape index_shape = ShapeUtil::MakeShape(S32, {}); + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "param1")); + auto param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, index_shape, "param2")); + auto copy0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param0)); + auto dynamic_update_slice = builder.AddInstruction( + HloInstruction::CreateDynamicUpdateSlice(shape, copy0, param1, {param2})); + + module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + + HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_EQ(analysis.GetUniqueBufferAt(copy0), + analysis.GetUniqueBufferAt(dynamic_update_slice)); +} + +TEST_F(HloAliasAnalysisTest, DynamicUpdateSliceMultiOutputFusion) { + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + param1 = f32[1280,1,128] parameter(1) + param2 = f32[1280,1,128] parameter(2) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + 
add.1 = f32[1280,1,128] add(param0, param0) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate0 = f32[1280,1,128] negate(param) + negate1 = f32[1280,1,128] negate(param) + negate2 = f32[1280,1,128] negate(param) + ROOT fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string)); + + SCOPED_TRACE(module_->ToString()); + + HloAliasAnalysis& analysis = RunAnalysis(); + LOG(INFO) << analysis.ToString(); + + // Expect negate1 and negate2 to alias with fusion{1} and fusion{2} + // respectively (due to DUS), but not negate0 and fusion{0}. + const HloInstruction* fusion = + module_->entry_computation()->GetInstructionWithName("fusion"); + const HloInstruction* negate0 = + module_->entry_computation()->GetInstructionWithName("negate0"); + const HloInstruction* negate1 = + module_->entry_computation()->GetInstructionWithName("negate1"); + const HloInstruction* negate2 = + module_->entry_computation()->GetInstructionWithName("negate2"); + EXPECT_EQ(analysis.GetUniqueBufferAt(negate1), + analysis.GetUniqueBufferAt(fusion, {1})); + EXPECT_EQ(analysis.GetUniqueBufferAt(negate2), + analysis.GetUniqueBufferAt(fusion, {2})); + EXPECT_NE(analysis.GetUniqueBufferAt(negate0), + analysis.GetUniqueBufferAt(fusion, {0})); +} + +TEST_F(HloAliasAnalysisTest, ChainedDynamicUpdateSliceFusion) { + // CPU and GPU backends may generate fusions with dynamic update slices + // feeding each other. They expect the fusion to not be in-place if that is + // the case. + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) + ROOT dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate0 = f32[1280,1,128] negate(param) + ROOT fusion = f32[1280,1,128] fusion(negate0), kind=kLoop, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string)); + + SCOPED_TRACE(module_->ToString()); + + HloAliasAnalysis& analysis = RunAnalysis(); + LOG(INFO) << analysis.ToString(); + + const HloInstruction* fusion = + module_->entry_computation()->GetInstructionWithName("fusion"); + const HloInstruction* negate0 = + module_->entry_computation()->GetInstructionWithName("negate0"); + EXPECT_NE(analysis.GetUniqueBufferAt(negate0), + analysis.GetUniqueBufferAt(fusion)); +} + TEST_F(HloAliasAnalysisTest, BitcastInterference) { // A bitcast value simultaneously live with its operand should not cause // interference. 
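The tests above rely on the dataflow analysis reporting which operand/output pairs of an instruction (or fusion) must alias. A small sketch of querying that information follows; it assumes the static HloDataflowAnalysis::GetInPlaceInputOutputPairs API added later in this change, and the LogMustAliasPairs helper is hypothetical.

// Sketch: list the operand/output pairs that must alias for an instruction,
// e.g. operand 0 of a dynamic-update-slice aliasing its output.
void LogMustAliasPairs(HloInstruction* instruction) {
  for (const auto& pair :
       HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
    const HloUse& operand_use = pair.first;        // operand number + index
    const ShapeIndex& output_index = pair.second;  // output shape index
    VLOG(2) << instruction->name() << ": operand "
            << operand_use.operand_number << " at index "
            << operand_use.operand_index.ToString()
            << " must alias output index " << output_index.ToString();
  }
}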
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 438aa6ff05f..75a6dcdfdd2 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -545,7 +545,7 @@ string HloComputation::ToString( if (options.print_percent()) { s << "%"; } - if (options.print_ids() || !IsEntryComputation()) { + if (options.print_ids()) { // Exclude entry computation's name because it includes and leads to // non-deterministic fingerprint. s << PrintName(name(), options.print_ids()) << " "; @@ -836,8 +836,9 @@ ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const { return program_shape; } -bool HloComputation::Equal(const HloComputation& other, - bool is_layout_sensitive) const { +bool HloComputation::EqualInternal(const HloComputation& other, + bool is_layout_sensitive, + bool ignore_channel_id_values) const { if (this == &other) { return true; } @@ -855,15 +856,21 @@ bool HloComputation::Equal(const HloComputation& other, continue; } visited.emplace(pair); - // TODO(b/123082518): Avoid recursively invoking == because it may + // TODO(b/123082518): Avoid recursively invoking Equal because it may // cause a stack overflow with deeply nested subcomputations. - bool identical_ignoring_operands = pair.first->Identical( - *pair.second, - [](const HloInstruction*, const HloInstruction*) { return true; }, - [](const HloComputation* a, const HloComputation* b) { - return *a == *b; - }, - is_layout_sensitive); + auto operands_eq = [](const HloInstruction*, const HloInstruction*) { + return true; + }; + auto comp_eq = [&](const HloComputation* a, const HloComputation* b) { + return a->EqualInternal(*b, is_layout_sensitive, + ignore_channel_id_values); + }; + bool identical_ignoring_operands = + ignore_channel_id_values + ? pair.first->IdenticalIgnoringChannelIdValues( + *pair.second, operands_eq, comp_eq, is_layout_sensitive) + : pair.first->Identical(*pair.second, operands_eq, comp_eq, + is_layout_sensitive); if (!identical_ignoring_operands) { return false; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index d640007886c..1dcf1d9d7d3 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -310,7 +310,19 @@ class HloComputation { ProgramShape ComputeProgramShape(bool include_ids = true) const; // Return whether `*this` and `other` are functionally equivalent. - bool Equal(const HloComputation& other, bool is_layout_sensitive) const; + bool Equal(const HloComputation& other, bool is_layout_sensitive) const { + return EqualInternal(other, is_layout_sensitive, + /*ignore_channel_id_values=*/false); + } + + // Same as Equal() but ignores channel ID value mismatches on instructions, as + // long as the two instructions both have channel IDs or neither has a channel + // ID. + bool EqualIgnoringChannelIdValues(const HloComputation& other, + bool is_layout_sensitive) const { + return EqualInternal(other, is_layout_sensitive, + /*ignore_channel_id_values=*/true); + } // Return whether `*this` and `other` are functionally equivalent. bool operator==(const HloComputation& other) const { @@ -489,6 +501,10 @@ class HloComputation { HloInstruction* AddInstructionInternal( std::unique_ptr instruction); + // Internal helper for comparison with different options. 
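A short usage sketch of the two comparison entry points on HloComputation; the SameModuloChannelIds helper is hypothetical and only meant to show how the new method relates to Equal().

bool SameModuloChannelIds(const HloComputation& a, const HloComputation& b) {
  // Equal() fails if any channel_id differs; EqualIgnoringChannelIdValues()
  // tolerates different values as long as both instructions either carry a
  // channel_id or both lack one.
  if (a.Equal(b, /*is_layout_sensitive=*/true)) {
    return true;  // Identical, channel IDs included.
  }
  return a.EqualIgnoringChannelIdValues(b, /*is_layout_sensitive=*/true);
}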
+ bool EqualInternal(const HloComputation& other, bool is_layout_sensitive, + bool ignore_channel_id_values) const; + // Fuses HLOs in instructions_to_fuse into fusion_instruction. // // Pre-condition: fusion_instruction's opcode is kFusion. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 72b15db0dcd..939c713fc18 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -486,6 +486,10 @@ Status HloCostAnalysis::HandleReshape(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleDynamicReshape(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) { // TODO(b/62294698): Implement cost analysis for batch-norm-training. return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index d9085dd7785..f101e3819c9 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -113,6 +113,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; + Status HandleDynamicReshape(const HloInstruction* reshape) override; Status HandleAddDependency(const HloInstruction* add_dependency) override; Status HandleAfterAll(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 4ba67888409..4aeeb6d27ac 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -92,16 +92,17 @@ StatusOr MakeSliceHlo(HloInstruction* operand, StatusOr MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + int64 batch_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape( - lhs->shape(), rhs->shape(), feature_group_count, 1, - window, dimension_numbers)); + lhs->shape(), rhs->shape(), feature_group_count, + batch_group_count, window, dimension_numbers)); return computation->AddInstruction(HloInstruction::CreateConvolve( - convolve_shape, lhs, rhs, feature_group_count, 1, window, + convolve_shape, lhs, rhs, feature_group_count, batch_group_count, window, dimension_numbers, precision_config)); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 2b17ae3d967..53eeeffb858 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -61,7 +61,8 @@ StatusOr MakeSliceHlo(HloInstruction* operand, // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). 
StatusOr MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + int64 batch_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config); // Creates a transpose HLO instruction and adds it to the computation containing diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index a46d20d5808..72899ffe163 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" @@ -32,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -42,7 +44,45 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { +// CalculatePostOrderSchedule traverses a module and assigns an ordinal to each +// instruction based on the postorder dependency. +int64 CalculatePostOrderScheduleHelper( + const HloComputation* comp, int64 start_ordinal, + absl::flat_hash_map* ordinal_map) { + int64 ordinal = start_ordinal; + for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kConditional) { + for (const HloComputation* called_computation : + instruction->called_computations()) { + ordinal = CalculatePostOrderScheduleHelper(called_computation, ordinal, + ordinal_map); + } + } + if (instruction->opcode() == HloOpcode::kWhile) { + ordinal = CalculatePostOrderScheduleHelper(instruction->while_condition(), + ordinal, ordinal_map); + ordinal = CalculatePostOrderScheduleHelper(instruction->while_body(), + ordinal, ordinal_map); + } + // It's possible that in some unit tests the computation graph is not + // flattened (meaning we could have multiple callers for one computation). In + // that case the ordinal_map will see the instruction multiple times. We + // consider that case to be ok as it only shows up in unit tests. + ordinal_map->insert({instruction, ordinal++}); + } + return ordinal; +} +absl::flat_hash_map CalculatePostOrderSchedule( + const HloModule& module) { + absl::flat_hash_map map; + CalculatePostOrderScheduleHelper(module.entry_computation(), 0, &map); + return map; +} + +} // namespace using absl::StrAppend; using absl::StrCat; @@ -757,27 +797,35 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( } void HloDataflowAnalysis::Propagate() { - std::queue worklist; + using Work = std::pair; + // Avoid duplicating work by preferring work items early in the post order + // schedule. Intuitively, we start from entry parameters and propagate buffer + // updates throughout the module only once.
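The worklist declared next keeps pending instructions in a min-heap keyed by the post-order ordinal computed above, so value-set updates flow from parameters toward the root and instructions are rarely revisited. A standalone sketch of that pattern, with a hypothetical helper name and the usual <queue> and absl container includes assumed, looks like this:

void VisitInPostOrderOrdinalOrder(
    const absl::flat_hash_map<HloInstruction*, int64>& ordinal_map,
    const std::vector<HloInstruction*>& initial) {
  using Work = std::pair<int64, HloInstruction*>;  // (ordinal, instruction)
  // std::greater turns the priority_queue into a min-heap, so the pending
  // instruction with the smallest post-order ordinal is popped first.
  std::priority_queue<Work, std::vector<Work>, std::greater<Work>> worklist;
  absl::flat_hash_set<HloInstruction*> workset;
  auto add_to_worklist = [&](HloInstruction* instruction) {
    // The workset keeps each pending instruction queued at most once.
    if (workset.insert(instruction).second) {
      worklist.emplace(ordinal_map.at(instruction), instruction);
    }
  };
  for (HloInstruction* instruction : initial) {
    add_to_worklist(instruction);
  }
  while (!worklist.empty()) {
    HloInstruction* instruction = worklist.top().second;
    worklist.pop();
    workset.erase(instruction);
    // Process `instruction` here; if its outputs changed, re-queue its users
    // with add_to_worklist(user).
  }
}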
+ std::priority_queue, std::greater> worklist; absl::flat_hash_set workset; - auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) { + auto priority_map = CalculatePostOrderSchedule(module_); + auto add_to_worklist = [&priority_map, &worklist, + &workset](HloInstruction* instruction) { if (workset.insert(instruction).second) { - worklist.push(instruction); + worklist.emplace(priority_map[instruction], instruction); } }; - for (HloComputation* computation : module_.computations()) { - for (HloInstruction* instruction : computation->instructions()) { + auto comps = module_.MakeComputationPostOrder(); + for (HloComputation* computation : comps) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { add_to_worklist(instruction); } } VLOG(1) << "SSA_FORM_: " << ssa_form_; while (!worklist.empty()) { - HloInstruction* instruction = worklist.front(); + HloInstruction* instruction = worklist.top().second; auto add_to_worklist = [&](HloInstruction* todo) { if (workset.insert(todo).second) { VLOG(1) << " Adding todo : " << todo->name(); - worklist.push(todo); + worklist.emplace(priority_map[todo], todo); } }; worklist.pop(); @@ -1130,69 +1178,49 @@ bool HloDataflowAnalysis::DoesNotUseOperandBuffer( return true; } -// Given a fusion whose root is a dynamic-update-slice op, determines whether -// the fusion's output buffer can be shared with the buffer of fusion_param, -// which must be a fused parameter of the fusion. -// -// Preconditions: -// -// - fusion's root is a dynamic-update-slice op. -// - fusion_param is a parameter within the fusion. -// -// fusion_param may point to a subelement of the actual parameter instruction if -// the param is a tuple; i.e. fusion_param->index() need not be the empty list. -// -// Returns true if: -// -// * fusion_param is used by the root of dynamic-update-slice as the "base" of -// the update, i.e. the thing being updated, AND -// * all other uses of fusion_param are dynamic-slices that slice the same -// indices as are overwritten in the dynamic-update-slice. -// -// In the case that there are no other uses of fusion_param (last bullet point -// is vacuously true) it's easy to see why an in-place DUS is safe; this is just -// the "natural" implementation of DUS. If there are other users, in-place DUS -// is safe on the assumption that the thread which writes element i of the -// output will be the only one to read element i of fusion_param (via the -// dynamic-slice ops). -static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion, - const HloValue& fusion_param_value) { - auto* root = - Cast(fusion->fused_expression_root()); - auto* fusion_param = fusion_param_value.instruction(); - CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter); - CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation()); +/*static*/ bool HloDataflowAnalysis::IsInPlaceOperation(HloOpcode opcode) { + return opcode == HloOpcode::kDynamicUpdateSlice || + opcode == HloOpcode::kScatter; +} - // fusion_param must be used by the root as the "base" of the - // dynamic-update-slice. The natural way to check this would be - // - // `if (root->operand(0) != fusion_param)` - // - // but we also have to handle the case where the fusion parameter is - // tuple-shaped and we're considering just one element of that tuple, i.e. - // fusion_param.index() != {}. 
- if (absl::c_count_if(fusion_param_value.uses(), [&](const HloUse& use) { - return use.instruction == root; - }) != 1) { - return false; +/*static*/ std::vector> +HloDataflowAnalysis::GetInPlaceInputOutputPairs(HloInstruction* instruction) { + if (IsInPlaceOperation(instruction->opcode())) { + return {{HloUse{instruction, 0, {}}, {}}}; + } else if (instruction->opcode() != HloOpcode::kFusion) { + return {}; } - - // All other uses of fusion_param must be dynamic-slices that slice the same - // indices as are overwritten by the dynamic-update-slice. - for (const HloUse& use : fusion_param_value.uses()) { - auto* user = use.instruction; - if (user == root) { - continue; + std::vector> input_output_pairs; + for (auto& indexed_shape : ShapeUtil::GetLeafShapes(instruction->shape())) { + const HloInstruction* hlo_generating_output = + instruction->fused_expression_root(); + for (int64 i = 0; i < indexed_shape.index.size(); ++i) { + if (hlo_generating_output->opcode() == HloOpcode::kTuple) { + hlo_generating_output = + hlo_generating_output->operand(indexed_shape.index[i]); + } else { + CHECK_EQ(i, indexed_shape.index.size() - 1); + } } - // Check that `user` is a dynamic-slice op and has the same slice indices as - // `root`. - auto* ds = DynCast(user); - if (!ds || ds->index_operands() != root->index_operands()) { - return false; + if (IsInPlaceOperation(hlo_generating_output->opcode())) { + ShapeIndex operand_index; + const HloInstruction* fusion_parameter = + hlo_generating_output->operand(0); + while (fusion_parameter->opcode() == HloOpcode::kGetTupleElement) { + operand_index.push_front(fusion_parameter->tuple_index()); + fusion_parameter = fusion_parameter->operand(0); + } + + if (fusion_parameter->opcode() == HloOpcode::kParameter) { + input_output_pairs.emplace_back( + HloUse{instruction, fusion_parameter->parameter_number(), + operand_index}, + indexed_shape.index); + } } } - return true; + return input_output_pairs; } bool HloDataflowAnalysis::CanShareOperandBufferWithUser( @@ -1213,24 +1241,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( return false; } - if (user->opcode() == HloOpcode::kFusion) { - // Get the parameter associated with 'operand'; - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - - const HloValue& fusion_param_value = - GetValueDefinedAt(fusion_param, operand_index); - - // TODO(b/80315712): This code is in a bit of a weird intermediate state - // at the moment. The in-place DUS check really needs to be common to all - // backends, so it runs first. Then we run the backend-specific check if - // provided, or go through the target-independent check if not. - // Unfortunately, the notionally "target-independent" path actually contains - // some target-specific code, so we can't run all of it *in addition* to the - // target-specific function, like the interface documentation says. - if (user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value); + // Must-alias relationship returns true for in-place operations (DUS and DUS + // fusions), regardless of the backend. 
+ for (const auto& operand_and_output_index : + GetInPlaceInputOutputPairs(user)) { + if (operand_and_output_index.second != user_index) { + continue; + } + for (const HloUse& use : GetUniqueValueAt(operand, operand_index).uses()) { + if (use == operand_and_output_index.first) { + return true; + } } } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index bec592aeb20..ffa307d71dd 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -49,6 +49,9 @@ class HloDataflowAnalysis { // Infrastructure for passing may-alias hints: HLO passes can populate the // may-alias table. If an empty optional is returned, default rules are used. // + // Must-alias rules (as defined by GetInPlaceInputOutputPairs) cannot be + // overriden using backend-specific overrides. + // // The first parameter of the function should be the instruction, the // second parameter should be an operand of the instruction. The third // parameter should be the output index of the instruction. @@ -160,6 +163,15 @@ class HloDataflowAnalysis { const HloModule& module() const { return module_; } + // Returns true if the operation is an in-place operation and its operand 0 + // must alias with the output. + static bool IsInPlaceOperation(HloOpcode opcode); + + // Returns a vector consisting of the HloUse (operand number and shape index) + // and output shape index of the in-place operations within this HLO. + static std::vector> GetInPlaceInputOutputPairs( + HloInstruction* instruction); + protected: HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value = false, diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 1bbbb248bbc..1fa6fe95c40 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1229,10 +1229,10 @@ TEST_P(HloDataflowAnalysisTest, CopyStartAndCopyDone) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary( + auto copy_start = builder.AddInstruction(HloInstruction::CreateCopyStart( ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(), ShapeUtil::MakeShape(U32, {})}), - HloOpcode::kCopyStart, constant)); + constant)); auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopyDone, copy_start)); module_->AddEntryComputation(builder.Build()); @@ -2324,36 +2324,6 @@ TEST_F(CanShareOperandBufferWithUserTest, dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {})); } -TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithDifferentIndices) { - const char* kModule = R"( - HloModule test - - fused_computation { - p0 = f32[10,20,30] parameter(0) - p1 = s32[] parameter(1) - p2 = s32[] parameter(2) - p3 = s32[] parameter(3) - slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30} - ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p3, p2) - } - - ENTRY test { - p0 = f32[10,20,30] parameter(0) - p1 = s32[] parameter(1) - p2 = s32[] parameter(2) - p3 = s32[] parameter(3) - ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation - } - )"; - 
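// ---------------------------------------------------------------------------
// Illustrative sketch (editorial; assumes the XLA build environment): building
// a copy-start with the dedicated factory used in the updated test above. The
// result shape is a (destination, source, context) tuple; `operand` is a
// hypothetical HloInstruction* of shape f32[].
Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
std::unique_ptr<HloInstruction> copy_start = HloInstruction::CreateCopyStart(
    ShapeUtil::MakeTupleShape(
        {f32_scalar, f32_scalar, ShapeUtil::MakeShape(U32, {})}),
    operand, /*is_cross_program_prefetch=*/true);
// ---------------------------------------------------------------------------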
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule)); - auto* fusion = module_->entry_computation()->root_instruction(); - auto* param = module_->entry_computation()->parameter_instruction(0); - - RunAnalysis(); - EXPECT_FALSE( - dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {})); -} - TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithSameIndices) { const char* kModule = R"( HloModule test diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 66e9e01fc38..acccf7aac9a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1573,9 +1573,9 @@ class OutputBatchIndexToInputIndex { int64 index_vector_dim = dim_numbers_.index_vector_dim(); for (int64 i = 0, e = index_vector_.size(); i < e; i++) { index_vector_index_[index_vector_dim] = i; - // TODO(george): OK what should happen here? - // seems OK to crash though. - index_vector_[i] = *start_indices_.GetIntegralAsS64(index_vector_index_); + auto start_index = start_indices_.GetIntegralAsS64(index_vector_index_); + TF_RET_CHECK(start_index.has_value()); + index_vector_[i] = *start_index; } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 9226cd556ff..b91ec9d86ee 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -48,22 +48,26 @@ template struct is_complex_t : absl::disjunction, std::is_same> {}; +namespace detail { +template +using unsigned_promoted_type_t = + std::make_unsigned_t() + std::declval())>; +} + // ToArithmeticSafeType(T t): -// - converts `t` to the bitwise-equivalent `unsigned T` if T is a signed +// - converts `t` to an unsigned integer at least as wide as `int` if T is an // integer, and // - otherwise returns `t` unchanged. // // It's UB in C++ to under/overflow a signed integer, so we wrap all arithmetic // in this type to force 2's complement behavior. template ::value && - std::is_signed::value>::type* = nullptr> -typename std::make_unsigned::type ToArithmeticSafeType(T t) { - return static_cast::type>(t); + typename std::enable_if::value>::type* = nullptr> +detail::unsigned_promoted_type_t ToArithmeticSafeType(T t) { + return static_cast>(t); } template ::value || - !std::is_signed::value>::type* = nullptr> + typename std::enable_if::value>::type* = nullptr> T ToArithmeticSafeType(T t) { return std::move(t); } @@ -1153,7 +1157,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 feature_group_index = out_index[output_z_dim] / output_feature_group_size; - const int64 batch_group_index = out_index[output_z_dim]; + const int64 depthwise_multiplier = + batch_group_count > 1 ? 
output_z_size / input_batch_size : 1; + const int64 batch_group_index = + out_index[output_z_dim] / depthwise_multiplier; ElementwiseT result_val = static_cast(0); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), @@ -1214,7 +1221,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { feature_group_index * input_feature_group_size + rhs_iz; int64 lhs_linear_index = lhs_linear_spatial_index; - lhs_linear_index += out_index[output_batch_dim] * lhs_dim_multipliers[input_batch_dim]; @@ -1229,7 +1235,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { lhs_dim_multipliers[input_batch_dim]; lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; - int64 rhs_linear_index = rhs_linear_spatial_index; rhs_linear_index += out_index[output_z_dim] * diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc deleted file mode 100644 index 9415e20af7b..00000000000 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" - -#include "absl/algorithm/container.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" - -namespace xla { - -namespace { - -StatusOr ReplaceGetSize( - HloInstruction* instr, - DynamicDimensionInference* dynamic_dimension_inference) { - if (instr->opcode() != HloOpcode::kGetDimensionSize) { - return false; - } - HloComputation* computation = instr->parent(); - - TF_ASSIGN_OR_RETURN(auto legal_shape, - ShapeInference::InferGetDimensionSizeShape( - instr->operand(0)->shape(), instr->dimension())); - TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)) - << "instr->shape() " << instr->shape().ToString() << " , " - << "legal_shape " << legal_shape.ToString(); - TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32)); - HloInstruction* operand = instr->mutable_operand(0); - int64 dim = instr->dimension(); - HloInstruction* dynamic_size = - dynamic_dimension_inference->GetDynamicSize(operand, {}, dim); - if (dynamic_size != nullptr) { - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); - // The dependency between a instruction and its dynamic dimensions is not - // modeled in the IR. As instr is being replaced by dynamic_size, also tell - // dynamic dimension inference that the instruction is being replaced. 
- dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith( - instr, dynamic_size); - } else { - int32 size = instr->operand(0)->shape().dimensions(dim); - HloInstruction* new_instr = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); - dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr, - new_instr); - } - return true; -} - -StatusOr ReplaceSetSize(HloInstruction* instr) { - if (instr->opcode() != HloOpcode::kSetDimensionSize) { - return false; - } - - TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()( - instr->shape(), instr->operand(0)->shape())) - << "instr->shape() " << instr->shape().ToString() << " , " - << "instruction operand shape " << instr->operand(0)->shape(); - HloInstruction* operand = instr->mutable_operand(0); - - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand)); - return true; -} - -} // namespace - -StatusOr HloGetDimensionSizeRewriter::Run(HloModule* module) { - bool changed = false; - HloProto proto; - TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference, - DynamicDimensionInference::Run(module)); - *proto.mutable_hlo_module() = module->ToProto(); - // It's important to replace get-dimension-size first before - // set-dimension-size for the case below: - // static_op dynamic_size - // | | - // set-dimension-size // Marks the dimension as dynamic - // | - // get-dimension-size - // - // If we replace set dimension size first, we'd have - // - // static_op - // | - // get-dimension-size - // - // This will get static size of the op, which is incorrect. - for (auto* computation : module->computations()) { - for (auto instruction : computation->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool replaced_get_size, - ReplaceGetSize(instruction, &inference)); - changed = changed || replaced_get_size; - } - } - for (auto* computation : module->computations()) { - for (auto instruction : computation->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction)); - changed = changed || replaced_set_size; - } - } - return changed; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc deleted file mode 100644 index b1491e96095..00000000000 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" - -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { -namespace { - -namespace op = xla::testing::opcode_matchers; - -class HloGetDimensionSizeRewriterTest : public HloTestBase { - protected: - HloGetDimensionSizeRewriterTest() {} -}; - -TEST_F(HloGetDimensionSizeRewriterTest, Ok) { - auto module = ParseAndReturnVerifiedModule(R"( -HloModule _ -ENTRY gds { - p = s32[3,4] parameter(0) - size0 = s32[] get-dimension-size(p), dimensions={0} - size1 = s32[] get-dimension-size(p), dimensions={1} - ROOT mul = s32[] multiply(size0, size1) -})") - .ValueOrDie(); - HloGetDimensionSizeRewriter pass; - EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Multiply(op::Constant(), op::Constant())); -} - -TEST_F(HloGetDimensionSizeRewriterTest, GetSetSetDimensionSizeRewriter) { - auto module = ParseAndReturnVerifiedModule(R"( -HloModule _ -ENTRY gds { - p = s32[3,4] parameter(0) - size0 = s32[] get-dimension-size(p), dimensions={0} - p_copy = s32[3,4] copy(p) - p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0} - size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0} - ROOT mul = s32[] multiply(size0, size1) -})") - .ValueOrDie(); - HloGetDimensionSizeRewriter pass; - EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Multiply(op::Constant(), op::Constant())); -} - -TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) { - auto module = ParseAndReturnUnverifiedModule(R"( -HloModule _ -ENTRY gds { - p = s32[3]{0} parameter(0) - ROOT gds = s64[] get-dimension-size(p), dimensions={0} -})") - .ValueOrDie(); - HloGetDimensionSizeRewriter pass; - EXPECT_FALSE(pass.Run(module.get()).ok()); -} - -TEST_F(HloGetDimensionSizeRewriterTest, IllegalDimension) { - auto module = ParseAndReturnUnverifiedModule(R"( -HloModule _ -ENTRY gds { - p = f32[2,5] parameter(0) - ROOT gds = s32[] get-dimension-size(p), dimensions={2} -})") - .ValueOrDie(); - HloGetDimensionSizeRewriter pass; - EXPECT_FALSE(pass.Run(module.get()).ok()); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index d7e8984dee8..164e92ae8e8 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1012,6 +1012,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kGather: case HloOpcode::kPad: case HloOpcode::kReshape: + case HloOpcode::kDynamicReshape: case HloOpcode::kReverse: case 
HloOpcode::kTupleSelect: case HloOpcode::kTranspose: diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc index e123161720b..34bc30d641f 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_module.h" namespace xla { @@ -24,9 +25,10 @@ bool HloInputOutputAliasConfig::OutputHasAlias( return alias_.element(output_index).has_value(); } -Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, - int64 param_number, - const ShapeIndex& param_index) { +Status HloInputOutputAliasConfig::SetUpAlias( + const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index, + HloInputOutputAliasConfig::AliasKind must_alias) { TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)) << "Trying to set up alias at " << output_index.ToString() << " which is an invalid index for shape " @@ -41,7 +43,8 @@ Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, param_number, param_index.ToString(), output_index.ToString(), alias_.element(output_index)->parameter_number, alias_.element(output_index)->parameter_index.ToString()); - (*alias_.mutable_element(output_index)) = Alias(param_number, param_index); + (*alias_.mutable_element(output_index)) = + Alias(param_number, param_index, must_alias); VLOG(4) << "Set up alias between output index " << output_index.ToString() << " and parameter " << param_index << " at index " << param_index.ToString(); @@ -61,6 +64,11 @@ HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const { for (int64 i : data->parameter_index) { entry.add_parameter_shape_index(i); } + if (data->must_alias()) { + entry.set_kind(Kind::MUST_ALIAS); + } else { + entry.set_kind(Kind::MAY_ALIAS); + } result.add_entries()->Swap(&entry); } }); @@ -77,8 +85,9 @@ StatusOr HloInputOutputAliasConfig::CreateFromProto( int64 param_number = entry.parameter_number(); ShapeIndex param_index(entry.parameter_shape_index().begin(), entry.parameter_shape_index().end()); + AliasKind kind = entry.kind() == Kind::MAY_ALIAS ? kMayAlias : kMustAlias; TF_RETURN_IF_ERROR( - result.SetUpAlias(output_index, param_number, param_index)); + result.SetUpAlias(output_index, param_number, param_index, kind)); } return result; } @@ -93,9 +102,9 @@ string HloInputOutputAliasConfig::ToString() const { ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) { pieces.push_back(absl::StrFormat( - " OutputIndex %s is aliased with parameter %lld at %s:", - output_index.ToString(), alias.parameter_number, - alias.parameter_index.ToString())); + " OutputIndex %s is %saliased with parameter %lld at %s:", + output_index.ToString(), alias.kind == kMustAlias ? 
"must-" : "may-", + alias.parameter_number, alias.parameter_index.ToString())); }); return absl::StrJoin(pieces, "\n"); } @@ -112,6 +121,19 @@ string HloInputOutputAliasConfig::ToShortString() const { return absl::StrJoin(pieces, ", "); } +bool HloInputOutputAliasConfig::ParameterMustAlias( + int64 param_number, const ShapeIndex& param_index) const { + bool result = false; + alias_.ForEachElement( + [&](const xla::ShapeIndex&, absl::optional alias) { + if (alias && alias->parameter_number == param_number && + alias->parameter_index == param_index && alias->must_alias()) { + result = true; + } + }); + return result; +} + absl::optional HloInputOutputAliasConfig::GetAliasedOutput( int64 param_number, const ShapeIndex& param_index) const { absl::optional output; diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h index d5ca28e9387..d5630467783 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h @@ -32,22 +32,32 @@ class HloModule; // parameter index in the entry computation. class HloInputOutputAliasConfig { public: + // The kind of aliases which can be set. A kMayAlias is one setup at + // compilation time by the user, and has to be respected. A kMustAlias one + // might be setup by the compiler, if it decides it is convenient to do so. + enum AliasKind { + kMayAlias, + kMustAlias, + }; // Defines the alias information for a given output buffer. A given output // buffer shape index can refer only to one parameter+index. struct Alias { - Alias(int64 parameter_number, ShapeIndex parameter_index) + Alias(int64 parameter_number, ShapeIndex parameter_index, + AliasKind kind = kMayAlias) : parameter_number(parameter_number), - parameter_index(std::move(parameter_index)) {} + parameter_index(std::move(parameter_index)), + kind(kind) {} int64 parameter_number; ShapeIndex parameter_index; + AliasKind kind; + + bool must_alias() const { return kind == kMustAlias; } std::string ToString() { - if (parameter_index.empty()) { - return absl::StrCat(parameter_number); - } - return absl::StrFormat("(%lld, %s)", parameter_number, - parameter_index.ToString()); + return absl::StrFormat("(%lld, %s, %s)", parameter_number, + parameter_index.ToString(), + kind == kMustAlias ? "must-alias" : "may-alias"); } }; @@ -61,7 +71,8 @@ class HloInputOutputAliasConfig { // Sets up alias config from `output_index` to `param_index` at // `param_number`. Status SetUpAlias(const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index); + const ShapeIndex& param_index, + AliasKind must_alias = kMayAlias); // Returns true if the given parameter is aliased with one of the output // buffers. @@ -92,6 +103,11 @@ class HloInputOutputAliasConfig { absl::optional GetAliasedParameter( const ShapeIndex& output_index) const; + // Returns if the parameter at the given parameter number and parameter + // index must-alias with an output. 
+ bool ParameterMustAlias(int64 param_number, + const ShapeIndex& param_index) const; + using AliasFn = std::function; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 94d53ebe0b1..251261a677f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -167,6 +167,11 @@ StatusOr> HloInstruction::CreateFromProto( absl::Span(fft_length)); break; } + case HloOpcode::kCopyStart: { + instruction = CreateCopyStart(shape, operands(0), + proto.is_cross_program_prefetch()); + break; + } case HloOpcode::kCompare: { // Auto-upgraded from deprecated opcode skips the following. if (!comparison_direction) { @@ -174,8 +179,19 @@ StatusOr> HloInstruction::CreateFromProto( comparison_direction, StringToComparisonDirection(proto.comparison_direction())); } - instruction = - CreateCompare(shape, operands(0), operands(1), *comparison_direction); + auto comparison_type_str = proto.comparison_type(); + if (!comparison_type_str.empty()) { + // If a comparison type is specified, it *must* be valid. + TF_ASSIGN_OR_RETURN(auto comparison_type, + StringToComparisonType(comparison_type_str)); + instruction = CreateCompare(shape, operands(0), operands(1), + *comparison_direction, comparison_type); + } else { + // Allow the specify of comparison type to be optional. + // The comparison type will be determined by the types of the operands. + instruction = CreateCompare(shape, operands(0), operands(1), + *comparison_direction); + } break; } case HloOpcode::kTriangularSolve: { @@ -689,6 +705,17 @@ StatusOr> HloInstruction::CreateFromProto( instruction = CreateReshape(shape, operands(0), inferred_dimension); break; } + case HloOpcode::kDynamicReshape: { + TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() && + ShapeUtil::ElementsIn(shape) == + ShapeUtil::ElementsIn(operands(0)->shape())) + << "shape: " << ShapeUtil::HumanString(shape) + << " operand: " << ShapeUtil::HumanString(operands(0)->shape()); + const auto& operand_vector = all_operands(); + instruction = CreateDynamicReshape( + shape, operands(0), absl::MakeSpan(operand_vector).subspan(1)); + break; + } default: { instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); for (const int64 operand_id : proto.operand_ids()) { @@ -817,7 +844,6 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, case HloOpcode::kCeil: case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCopy: - case HloOpcode::kCopyStart: case HloOpcode::kCopyDone: case HloOpcode::kCos: case HloOpcode::kClz: @@ -924,10 +950,18 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, fft_length); } +/* static */ std::unique_ptr HloInstruction::CreateCopyStart( + const Shape& shape, HloInstruction* operand, + bool is_cross_program_prefetch) { + return absl::make_unique(shape, operand, + is_cross_program_prefetch); +} + /* static */ std::unique_ptr HloInstruction::CreateCompare( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - ComparisonDirection direction) { - return absl::make_unique(shape, lhs, rhs, direction); + ComparisonDirection direction, absl::optional type) { + return absl::make_unique(shape, lhs, rhs, direction, + type); } /* static */ std::unique_ptr @@ -1361,6 +1395,19 @@ HloInstruction::CreateBroadcastSequence( inferred_dimension); } +/* static */ std::unique_ptr +HloInstruction::CreateDynamicReshape( + const Shape& shape, HloInstruction* data_operand, + 
absl::Span dim_sizes) { + CHECK_EQ(ShapeUtil::ElementsIn(shape), + ShapeUtil::ElementsIn(data_operand[0].shape())) + << "shape: " << ShapeUtil::HumanString(shape) + << " operand: " << ShapeUtil::HumanString(data_operand[0].shape()); + CHECK_EQ(shape.rank(), dim_sizes.size()); + return absl::make_unique(shape, data_operand, + dim_sizes); +} + /* static */ std::unique_ptr HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, absl::Span dimensions) { @@ -1557,6 +1604,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kTranspose: case HloOpcode::kBroadcast: case HloOpcode::kReshape: + case HloOpcode::kDynamicReshape: case HloOpcode::kMap: case HloOpcode::kSlice: case HloOpcode::kConstant: @@ -1894,6 +1942,56 @@ Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) { return Status::OK(); } +bool HloInstruction::IdenticalInternal( + const HloInstruction& other, + const std::function& + eq_operands, + const std::function& + eq_computations, + bool layout_sensitive, bool ignore_channel_id_values) const { + // An instruction is always identical to itself. + if (this == &other) { + return true; + } + + // Identical instruction must have the same opcode, shape, and identical + // operands. + if (opcode() != other.opcode()) { + return false; + } + if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape()) + : ShapeUtil::Compatible(shape(), other.shape()))) { + return false; + } + if (operands().size() != other.operands().size()) { + return false; + } + + // Two AllReduces are Identical if they have the same channel_id. + // Their operands don't have to be Identical. + if (!IsCrossModuleAllReduce()) { + // Use an explicit loop rather than ContainerEquals, because copying + // around std::functions may be too expensive in some cases. + for (size_t i = 0; i < operands().size(); ++i) { + if (!eq_operands(operand(i), other.operand(i))) { + return false; + } + } + } + + if (backend_config_ != other.backend_config_) { + return false; + } + + if (ignore_channel_id_values) { + if (auto channel_inst = DynCast(this)) { + return channel_inst->IdenticalSlowPathIgnoringChannelIdValues( + other, eq_computations); + } + } + return IdenticalSlowPath(other, eq_computations); +} + void HloInstruction::AppendOperand(HloInstruction* operand) { if (operand->parent() != nullptr) { DCHECK(!operand->parent()->IsMarkedAsDead(operand)) @@ -1995,6 +2093,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReal: case HloOpcode::kRemainder: case HloOpcode::kReshape: + case HloOpcode::kDynamicReshape: case HloOpcode::kReplicaId: case HloOpcode::kRoundNearestAfz: case HloOpcode::kRsqrt: @@ -2800,7 +2899,8 @@ HloInstructionProto HloInstruction::ToProto() const { string HloInstruction::ToCategory() const { if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy || - opcode() == HloOpcode::kReshape) { + opcode() == HloOpcode::kReshape || + opcode() == HloOpcode::kDynamicReshape) { return "data formatting"; } @@ -3021,6 +3121,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandlePad(this); case HloOpcode::kReshape: return visitor->HandleReshape(this); + case HloOpcode::kDynamicReshape: + return visitor->HandleDynamicReshape(this); case HloOpcode::kTranspose: return visitor->HandleTranspose(this); case HloOpcode::kReverse: @@ -3318,6 +3420,11 @@ class HloInstruction::FusionReusesParamElements { // that. value_it = cache->find(&hlo); value_it->second = new_val; + // Fold() minimizes the UseKind value. 
If it is already minimum, we can + // break the loop early. + if (new_val == UseKind::kReuse) { + break; + } } } return value_it->second; @@ -3939,6 +4046,10 @@ const Shape& HloInstruction::outfeed_shape() const { return Cast(this)->outfeed_shape(); } +Shape* HloInstruction::mutable_outfeed_shape() { + return Cast(this)->mutable_outfeed_shape(); +} + const string& HloInstruction::outfeed_config() const { return Cast(this)->outfeed_config(); } @@ -4077,6 +4188,10 @@ const DomainMetadata& HloInstruction::user_side_metadata() const { return Cast(this)->user_side_metadata(); } +bool HloInstruction::is_cross_program_prefetch() const { + return Cast(this)->is_cross_program_prefetch(); +} + ComparisonDirection HloInstruction::comparison_direction() const { return Cast(this)->direction(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e29323c25b4..e21ae719e4d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -592,10 +592,17 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, FftType fft_type, absl::Span fft_length); + // Creates a copy-start op, indicating whether this is a cross-program + // prefetch or not. + static std::unique_ptr CreateCopyStart( + const Shape& shape, HloInstruction* operand, + bool is_cross_program_prefetch = false); + // Creates a compare op, performing the comparison specified in direction. static std::unique_ptr CreateCompare( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - Comparison::Direction direction); + Comparison::Direction direction, + absl::optional type = absl::nullopt); static std::unique_ptr CreateTriangularSolve( const Shape& shape, HloInstruction* a, HloInstruction* b, @@ -878,6 +885,14 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, int64 inferred_dimension = -1); + // Creates a dynamic reshape instruction. Similar to reshape but dynamic + // dimensions sizes are provided as additional variadic arguments. + // + // Precondition: dim_sizes.size() == shape.rank() + static std::unique_ptr CreateDynamicReshape( + const Shape& shape, HloInstruction* data_operand, + absl::Span dim_sizes); + // Creates a transpose instruction which permutes the operand dimensions. static std::unique_ptr CreateTranspose( const Shape& shape, HloInstruction* operand, @@ -1107,41 +1122,23 @@ class HloInstruction { const std::function& eq_computations = std::equal_to(), bool layout_sensitive = true) const { - // An instruction is always identical to itself. - if (this == &other) { - return true; - } + return IdenticalInternal(other, eq_operands, eq_computations, + layout_sensitive, + /*ignore_channel_id_values=*/false); + } - // Identical instruction must have the same opcode, shape, and identical - // operands. - if (opcode() != other.opcode()) { - return false; - } - if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape()) - : ShapeUtil::Compatible(shape(), other.shape()))) { - return false; - } - if (operands().size() != other.operands().size()) { - return false; - } - - // Two AllReduces are Identical if they have the same channel_id. - // Their operands don't have to be Identical. - if (!IsCrossModuleAllReduce()) { - // Use an explicit loop rather than ContainerEquals, because copying - // around std::functions may be too expensive in some cases. 
- for (size_t i = 0; i < operands().size(); ++i) { - if (!eq_operands(operand(i), other.operand(i))) { - return false; - } - } - } - - if (backend_config_ != other.backend_config_) { - return false; - } - - return IdenticalSlowPath(other, eq_computations); + // Same as Identical() but ignores channel ID value mismatches, as long as + // both have channel IDs or neither has a channel ID. + bool IdenticalIgnoringChannelIdValues( + const HloInstruction& other, + const std::function& + eq_operands = std::equal_to(), + const std::function& + eq_computations = std::equal_to(), + bool layout_sensitive = true) const { + return IdenticalInternal(other, eq_operands, eq_computations, + layout_sensitive, + /*ignore_channel_id_values=*/true); } // Generates a hash value of an HLO instruction. Hash considers @@ -1772,6 +1769,9 @@ class HloInstruction { // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const; + // Returns the mutable shape for the Outfeed instruction. + Shape* mutable_outfeed_shape(); + // Delegates to HloCollectiveInstruction::replica_groups. const std::vector& replica_groups() const; @@ -1856,6 +1856,9 @@ class HloInstruction { // Delegates to HloDomainInstruction::user_side_metadata(). const DomainMetadata& user_side_metadata() const; + // Delegates to HloCopyStartInstruction::is_cross_program_prefetch(). + bool is_cross_program_prefetch() const; + // Delegates to HloCompareInstruction::direction(). ComparisonDirection comparison_direction() const; @@ -1944,6 +1947,14 @@ class HloInstruction { private: friend class HloComputation; + bool IdenticalInternal( + const HloInstruction& other, + const std::function& + eq_operands, + const std::function& + eq_computations, + bool layout_sensitive, bool ignore_channel_id_values) const; + // Implementation for non-common logic of CloneWithNewOperands. 
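// ---------------------------------------------------------------------------
// Illustrative sketch (editorial; assumes the XLA build environment): both
// entry points above funnel into IdenticalInternal(); the only difference is
// whether channel-id *values* participate in the comparison. `ar1` and `ar2`
// are hypothetical all-reduce instructions that are identical except for their
// channel ids (both have one).
CHECK(!ar1->Identical(*ar2));                        // channel ids differ
CHECK(ar1->IdenticalIgnoringChannelIdValues(*ar2));  // values ignored
// ---------------------------------------------------------------------------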
virtual std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 3d34fa03a80..c4c31dba9a4 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -204,12 +204,54 @@ std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( fft_length_); } -HloCompareInstruction::HloCompareInstruction(const Shape& shape, - HloInstruction* lhs, - HloInstruction* rhs, - ComparisonDirection direction) +HloCopyStartInstruction::HloCopyStartInstruction(const Shape& shape, + HloInstruction* operand, + bool is_cross_program_prefetch) + : HloInstruction(HloOpcode::kCopyStart, shape), + is_cross_program_prefetch_(is_cross_program_prefetch) { + AppendOperand(operand); +} + +HloInstructionProto HloCopyStartInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_is_cross_program_prefetch(is_cross_program_prefetch_); + return proto; +} + +std::vector HloCopyStartInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector result; + if (is_cross_program_prefetch()) { + result.push_back("is_cross_program_prefetch=true"); + } + return result; +} + +bool HloCopyStartInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return is_cross_program_prefetch() == + casted_other.is_cross_program_prefetch(); +} + +std::unique_ptr +HloCopyStartInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return absl::make_unique( + shape, new_operands[0], is_cross_program_prefetch()); +} + +HloCompareInstruction::HloCompareInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + ComparisonDirection direction, absl::optional type) : HloInstruction(HloOpcode::kCompare, shape), - compare_(direction, lhs->shape().element_type()) { + compare_(direction, type ? 
(*type) + : Comparison::DefaultComparisonType( + lhs->shape().element_type())) { AppendOperand(lhs); AppendOperand(rhs); } @@ -218,12 +260,21 @@ HloInstructionProto HloCompareInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_comparison_direction( ComparisonDirectionToString(compare_.GetDirection())); + proto.set_comparison_type(ComparisonTypeToString(compare_.GetType())); return proto; } std::vector HloCompareInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("direction=", ComparisonDirectionToString(direction()))}; + std::vector result; + result.push_back( + StrCat("direction=", ComparisonDirectionToString(direction()))); + if (compare_.GetType() != + Comparison::DefaultComparisonType(operand(0)->shape().element_type())) { + result.push_back( + StrCat("type=", ComparisonTypeToString(compare_.GetType()))); + } + return result; } bool HloCompareInstruction::IdenticalSlowPath( @@ -238,8 +289,8 @@ std::unique_ptr HloCompareInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return absl::make_unique(shape, new_operands[0], - new_operands[1], direction()); + return absl::make_unique( + shape, new_operands[0], new_operands[1], direction(), type()); } namespace { @@ -396,7 +447,10 @@ std::vector HloChannelInstruction::ExtraAttributesToStringImpl( bool HloChannelInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& - /*eq_computations*/) const { + eq_computations) const { + if (!IdenticalSlowPathIgnoringChannelIdValues(other, eq_computations)) { + return false; + } const auto& casted_other = static_cast(other); return channel_id() == casted_other.channel_id(); } @@ -424,7 +478,7 @@ std::vector HloSendRecvInstruction::ExtraAttributesToStringImpl( return attrs; } -bool HloSendRecvInstruction::IdenticalSlowPath( +bool HloSendRecvInstruction::IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const { @@ -545,13 +599,14 @@ std::vector HloCollectiveInstruction::ExtraAttributesToStringImpl( return result; } -bool HloCollectiveInstruction::IdenticalSlowPath( +bool HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const { const auto& casted_other = static_cast(other); - return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) && + return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues( + other, eq_computations) && constrain_layout() == casted_other.constrain_layout() && absl::c_equal(replica_groups(), casted_other.replica_groups(), [](const ReplicaGroup& a, const ReplicaGroup& b) { @@ -594,12 +649,13 @@ HloInstructionProto HloAllGatherInstruction::ToProto() const { return proto; } -bool HloAllGatherInstruction::IdenticalSlowPath( +bool HloAllGatherInstruction::IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const { const auto& casted_other = static_cast(other); - return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && + return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues( + other, eq_computations) && all_gather_dimension_ == casted_other.all_gather_dimension() && use_global_device_ids() == casted_other.use_global_device_ids(); } @@ -640,12 +696,13 @@ std::vector 
HloAllReduceInstruction::ExtraAttributesToStringImpl( return result; } -bool HloAllReduceInstruction::IdenticalSlowPath( +bool HloAllReduceInstruction::IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const { const auto& casted_other = static_cast(other); - return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && + return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues( + other, eq_computations) && constrain_layout() == casted_other.constrain_layout() && use_global_device_ids() == casted_other.use_global_device_ids() && eq_computations(to_apply(), casted_other.to_apply()); @@ -696,12 +753,13 @@ std::vector HloAllToAllInstruction::ExtraAttributesToStringImpl( return result; } -bool HloAllToAllInstruction::IdenticalSlowPath( +bool HloAllToAllInstruction::IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const { const auto& casted_other = static_cast(other); - return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && + return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues( + other, eq_computations) && split_dimension_ == casted_other.split_dimension(); } @@ -737,7 +795,7 @@ HloCollectivePermuteInstruction::ExtraAttributesToStringImpl( return result; } -bool HloCollectivePermuteInstruction::IdenticalSlowPath( +bool HloCollectivePermuteInstruction::IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const { @@ -746,7 +804,8 @@ bool HloCollectivePermuteInstruction::IdenticalSlowPath( } const auto& casted_other = static_cast(other); - return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) && + return HloChannelInstruction::IdenticalSlowPathIgnoringChannelIdValues( + other, eq_computations) && absl::c_equal(source_target_pairs(), casted_other.source_target_pairs(), [](const std::pair& a, @@ -1017,6 +1076,25 @@ HloBroadcastInstruction::CloneWithNewOperandsImpl( dimensions()); } +HloDynamicReshapeInstruction::HloDynamicReshapeInstruction( + const Shape& shape, HloInstruction* data_operand, + absl::Span dim_sizes) + : HloInstruction(HloOpcode::kDynamicReshape, shape) { + AppendOperand(data_operand); + for (auto operand : dim_sizes) { + AppendOperand(operand); + } +} + +std::unique_ptr +HloDynamicReshapeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_GE(new_operands.size(), 1); + return absl::make_unique( + shape, new_operands[0], new_operands.subspan(1)); +} + HloReshapeInstruction::HloReshapeInstruction(const Shape& shape, HloInstruction* operand, int64 inferred_dimension) diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 51317b32bd0..821849bb02f 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -132,12 +132,36 @@ class HloFftInstruction : public HloInstruction { std::vector fft_length_; }; +class HloCopyStartInstruction : public HloInstruction { + public: + explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand, + bool is_cross_program_prefetch); + + bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; } + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool 
IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + bool is_cross_program_prefetch_; +}; + class HloCompareInstruction : public HloInstruction { public: explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - ComparisonDirection direction); + ComparisonDirection direction, + absl::optional type); ComparisonDirection direction() const { return compare_.GetDirection(); } + Comparison::Type type() const { return compare_.GetType(); } HloInstructionProto ToProto() const override; private: @@ -220,6 +244,15 @@ class HloChannelInstruction : public HloInstruction { absl::optional channel_id() const { return channel_id_; } void set_channel_id(const absl::optional& channel_id); + // Whether this instruction is identical to `other` except for the values of + // channel IDs, as long as both have channel IDs or neither has a channel ID. + virtual bool IdenticalSlowPathIgnoringChannelIdValues( + const HloInstruction& other, + const std::function& + eq_computations) const { + return channel_id_.has_value() == other.channel_id().has_value(); + } + protected: explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape, const absl::optional& channel_id); @@ -228,10 +261,13 @@ class HloChannelInstruction : public HloInstruction { std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; + + // Do not override IdenticalSlowPath(). Override + // IdenticalSlowPathIgnoringChannelIdValues() instead. bool IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations) const override; + eq_computations) const final; absl::optional channel_id_; }; @@ -251,7 +287,7 @@ class HloSendRecvInstruction : public HloChannelInstruction { private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; - bool IdenticalSlowPath( + bool IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const override; @@ -339,7 +375,7 @@ class HloCollectiveInstruction : public HloChannelInstruction { std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; - bool IdenticalSlowPath( + bool IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const override; @@ -366,7 +402,7 @@ class HloAllGatherInstruction : public HloCollectiveInstruction { HloInstructionProto ToProto() const override; private: - bool IdenticalSlowPath( + bool IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const override; @@ -410,7 +446,7 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { HloInstructionProto ToProto() const override; private: - bool IdenticalSlowPath( + bool IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const override; @@ -447,7 +483,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction { HloInstructionProto ToProto() const override; private: - bool IdenticalSlowPath( + bool IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const override; @@ -477,7 +513,7 @@ class HloCollectivePermuteInstruction : public HloChannelInstruction { private: std::vector ExtraAttributesToStringImpl( 
const HloPrintOptions& options) const override; - bool IdenticalSlowPath( + bool IdenticalSlowPathIgnoringChannelIdValues( const HloInstruction& other, const std::function& eq_computations) const override; @@ -677,6 +713,25 @@ class HloBroadcastInstruction : public HloInstruction { std::vector dimensions_; }; +class HloDynamicReshapeInstruction : public HloInstruction { + public: + explicit HloDynamicReshapeInstruction( + const Shape& shape, HloInstruction* data_operand, + absl::Span dim_sizes); + + // Returns the input dim sizes dimensions, which is operands[1:] + absl::Span dim_sizes() const { + return absl::MakeSpan(operands()).subspan(1, operand_count()); + } + + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + // Returns the input dim size dimension, which is operands[1+i] + HloInstruction* dim_sizes(int64 i) const { return operands()[i + 1]; } +}; + class HloReshapeInstruction : public HloInstruction { public: explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand, @@ -1139,6 +1194,8 @@ class HloOutfeedInstruction : public HloInstruction { absl::string_view outfeed_config); // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const { return outfeed_shape_; } + // Returns the mutable shape for the Outfeed instruction. + Shape* mutable_outfeed_shape() { return &outfeed_shape_; } // Returns the config for the Outfeed instruction. const string& outfeed_config() const { return outfeed_config_; } void set_outfeed_config(const string& config) { outfeed_config_ = config; } diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 5502665e886..749193a83ef 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -281,6 +281,7 @@ TokKind HloLexer::LexIdentifier() { KEYWORD(ROOT); KEYWORD(maximal); KEYWORD(replicated); + KEYWORD(last_tile_dim_replicate); #undef KEYWORD @@ -495,6 +496,8 @@ string TokKindToString(TokKind kind) { return "kw_maximal"; case TokKind::kw_replicated: return "kw_replicated"; + case TokKind::kw_last_tile_dim_replicate: + return "kw_last_tile_dim_replicate"; case TokKind::kw_nan: return "kw_nan"; case TokKind::kw_inf: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index 6a59f180ad8..b8c7debaab4 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -61,6 +61,7 @@ enum class TokKind { kw_false, kw_maximal, kw_replicated, + kw_last_tile_dim_replicate, kw_nan, kw_inf, diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index cb5cbd05d65..9c6509d8b73 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -276,10 +276,10 @@ TEST_F(HloMatchersTest, AsyncCopyMatcher) { /*element_size_in_bits=*/0, /*memory_space=*/2); auto p0 = HloInstruction::CreateParameter(0, shape_memspace1, "p0"); - auto copy_start = HloInstruction::CreateUnary( + auto copy_start = HloInstruction::CreateCopyStart( ShapeUtil::MakeTupleShape( {shape_memspace2, shape_memspace1, ShapeUtil::MakeShape(U32, {})}), - HloOpcode::kCopyStart, p0.get()); + p0.get()); auto copy_done = HloInstruction::CreateUnary( shape_memspace2, HloOpcode::kCopyDone, copy_start.get()); diff --git 
a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 8ee8d332aff..076e31dc8eb 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -50,9 +50,9 @@ int64 PeakMemoryUseOfEntryComputation( HloComputation* computation = module->entry_computation(); const HloInstructionSequence& sequence = schedule.sequence(computation); - return HeapSimulator::Run(absl::make_unique(), - *computation, sequence, *alias_analysis, - size_function) + return HeapSimulator::Run( + absl::make_unique>(), + *computation, sequence, *alias_analysis, size_function) .ValueOrDie() .heap_size; } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index eaed707607d..8158d198799 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -51,12 +51,14 @@ string HloModuleConfig::compilation_cache_key() const { string key = absl::StrCat("profiling=", hlo_profiling_enabled()); StrAppend(&key, "::("); std::vector params; - for (const ShapeLayout& param_layout : - entry_computation_layout_->parameter_layouts()) { - params.push_back(param_layout.shape().DebugString()); + if (entry_computation_layout_.has_value()) { + for (const ShapeLayout& param_layout : + entry_computation_layout_->parameter_layouts()) { + params.push_back(param_layout.shape().DebugString()); + } + StrAppend(&key, absl::StrJoin(params, ", "), ") => ", + entry_computation_layout_->result_shape().SerializeAsString()); } - StrAppend(&key, absl::StrJoin(params, ", "), ") => ", - entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. 
static std::atomic counter{0}; diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 1625d0bbae4..b50c7d9a584 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -123,6 +123,7 @@ namespace xla { V(kRemainder, "remainder", 2) \ V(kReplicaId, "replica-id", 0) \ V(kReshape, "reshape", 1) \ + V(kDynamicReshape, "dynamic-reshape", kHloOpcodeIsVariadic) \ V(kReverse, "reverse", 1) \ V(kRng, "rng", kHloOpcodeIsVariadic) \ V(kRngGetAndUpdateState, "rng-get-and-update-state", 0) \ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 136e6702b21..cceb60a70e9 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kCustomCall: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kDynamicReshape: case HloOpcode::kFusion: case HloOpcode::kMap: case HloOpcode::kReduce: diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 0530062c43b..e2bbda3a607 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -194,6 +194,7 @@ class HloParserImpl : public HloParser { kBracedHloComputationList, kFftType, kComparisonDirection, + kComparisonType, kWindow, kConvolutionDimensionNumbers, kSharding, @@ -327,6 +328,7 @@ class HloParserImpl : public HloParser { bool ParseOpcode(HloOpcode* result); bool ParseFftType(FftType* result); bool ParseComparisonDirection(ComparisonDirection* result); + bool ParseComparisonType(Comparison::Type* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); bool ParseRandomAlgorithm(RandomAlgorithm* result); @@ -552,33 +554,39 @@ bool HloParserImpl::ParseAliasing(AliasingData* data) { return false; } - if (lexer_.GetKind() != TokKind::kLparen) { - // Short form: "{0}: 0", output index "{}" is assumed. - int64 param_num; - ParseInt64(¶m_num); - data->emplace(std::piecewise_construct, std::forward_as_tuple(out), - std::forward_as_tuple(param_num, ShapeIndex{})); - } else { - // Long form: "{0}: (0, {0})", output index is explicitly specified. 
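// ---------------------------------------------------------------------------
// Editorial note on the new optional "type" attribute handled via
// kComparisonType above: when it is omitted, the comparison type defaults from
// the operand element type (signed/unsigned comparison for integer types, the
// regular partial order for floating-point types), while TOTALORDER requests
// the float total-order comparison. Example HLO text, adapted from the parser
// tests later in this change:
//
//   ROOT %ge = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE, type=TOTALORDER
// ---------------------------------------------------------------------------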
- if (!ParseToken(TokKind::kLparen, errmsg)) { - return false; - } - int64 param_num; - ParseInt64(¶m_num); - if (!ParseToken(TokKind::kComma, errmsg)) { - return false; - } - ShapeIndex param_idx; - if (!ParseShapeIndex(¶m_idx)) { - return false; - } - data->emplace(std::piecewise_construct, std::forward_as_tuple(out), - std::forward_as_tuple(param_num, param_idx)); - if (!ParseToken(TokKind::kRparen, errmsg)) { - return false; + if (!ParseToken(TokKind::kLparen, errmsg)) { + return false; + } + int64 param_num; + ParseInt64(¶m_num); + if (!ParseToken(TokKind::kComma, errmsg)) { + return false; + } + ShapeIndex param_idx; + if (!ParseShapeIndex(¶m_idx)) { + return false; + } + + HloInputOutputAliasConfig::AliasKind alias_kind = + HloInputOutputAliasConfig::kMayAlias; + if (EatIfPresent(TokKind::kComma)) { + std::string type; + ParseName(&type); + if (type == "must-alias") { + alias_kind = HloInputOutputAliasConfig::kMustAlias; + } else if (type == "may-alias") { + alias_kind = HloInputOutputAliasConfig::kMayAlias; + } else { + return TokenError("Unexpected aliasing kind; expected SYSTEM or USER"); } } + data->emplace(std::piecewise_construct, std::forward_as_tuple(out), + std::forward_as_tuple(param_num, param_idx, alias_kind)); + if (!ParseToken(TokKind::kRparen, errmsg)) { + return false; + } + if (!EatIfPresent(TokKind::kComma)) { break; } @@ -624,8 +632,9 @@ bool HloParserImpl::ParseHloModule(HloModule* module) { if (aliasing_data) { HloInputOutputAliasConfig alias_config(module->result_shape()); for (auto& p : *aliasing_data) { - Status st = alias_config.SetUpAlias(p.first, p.second.parameter_number, - p.second.parameter_index); + Status st = + alias_config.SetUpAlias(p.first, p.second.parameter_number, + p.second.parameter_index, p.second.kind); if (!st.ok()) { return TokenError(st.error_message()); } @@ -874,7 +883,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kClz: case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCopy: - case HloOpcode::kCopyStart: case HloOpcode::kCopyDone: case HloOpcode::kCos: case HloOpcode::kExp: @@ -1082,6 +1090,20 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, } break; } + case HloOpcode::kCopyStart: { + // If the is_cross_program_prefetch attribute is not present then default + // to false. 
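// ---------------------------------------------------------------------------
// Editorial note on the aliasing syntax parsed above: only the parenthesized
// long form is accepted now, with an optional third element naming the alias
// kind ("must-alias" or "may-alias"; may-alias is the default when the element
// is absent). Example module header, adapted from the parser tests later in
// this change:
//
//   HloModule M, input_output_alias={ {0}: (0, {0}, must-alias), {1}: (0, {1}) }
// ---------------------------------------------------------------------------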
+ optional is_cross_program_prefetch = false; + attrs["is_cross_program_prefetch"] = {/*required=*/false, AttrTy::kBool, + &is_cross_program_prefetch}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateCopyStart( + shape, operands[0], *is_cross_program_prefetch)); + break; + } case HloOpcode::kReplicaId: { if (!ParseOperands(&operands, /*expected_size=*/0) || !ParseAttributes(attrs)) { @@ -1099,6 +1121,16 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, builder->AddInstruction(HloInstruction::CreatePartitionId()); break; } + case HloOpcode::kDynamicReshape: { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateDynamicReshape( + shape, operands[0], + absl::Span(operands).subspan(1))); + break; + } case HloOpcode::kReshape: { optional inferred_dimension; attrs["inferred_dimension"] = {/*required=*/false, AttrTy::kInt64, @@ -1355,14 +1387,16 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, } case HloOpcode::kCompare: { optional direction; + optional type; attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection, &direction}; + attrs["type"] = {/*required=*/false, AttrTy::kComparisonType, &type}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateCompare( - shape, operands[0], operands[1], *direction)); + shape, operands[0], operands[1], *direction, type)); break; } case HloOpcode::kCholesky: { @@ -2129,6 +2163,7 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; + bool last_tile_dim_replicate = false; std::vector devices; std::vector tile_assignment_dimensions; while (lexer_.GetKind() != TokKind::kRbrace) { @@ -2180,6 +2215,10 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, } break; } + case TokKind::kw_last_tile_dim_replicate: + last_tile_dim_replicate = true; + lexer_.Lex(); + break; case TokKind::kRbrace: break; default: @@ -2218,6 +2257,7 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, for (int64 device : devices) { sharding->add_tile_assignment_devices(device); } + sharding->set_replicate_on_last_tile_dim(last_tile_dim_replicate); } lexer_.Lex(); @@ -3005,6 +3045,14 @@ bool HloParserImpl::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kComparisonType: { + Comparison::Type result; + if (!ParseComparisonType(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } case AttrTy::kEnum: { if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects an enumeration value"); @@ -4132,6 +4180,21 @@ bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) { return true; } +bool HloParserImpl::ParseComparisonType(Comparison::Type* result) { + VLOG(1) << "ParseComparisonType"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects comparison type"); + } + std::string val = lexer_.GetStrVal(); + auto status_or_result = StringToComparisonType(val); + if (!status_or_result.ok()) { + return TokenError(StrFormat("expects comparison type but sees: %s", val)); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + bool 
HloParserImpl::ParseFusionKind(HloInstruction::FusionKind* result) { VLOG(3) << "ParseFusionKind"; if (lexer_.GetKind() != TokKind::kIdent) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 484578e5e0e..620e67c3a2f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -230,7 +230,7 @@ R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { %v1 = f32[4]{0} parameter(0), sharding={maximal device=1} %v2 = f32[4]{0} parameter(1), sharding={maximal device=1} - %greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated} + %greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, type=TOTALORDER, sharding={replicated} ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={} } @@ -318,7 +318,7 @@ R"(HloModule CopyStartAndCopyDone_module ENTRY %CopyStartAndCopyDone (v1: f32[], v2: f32[2,3]) -> (f32[], f32[2,3]) { %v1 = f32[] parameter(0) - %copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1) + %copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1), is_cross_program_prefetch=true %copy-done.1 = f32[] copy-done((f32[], f32[], u32[]) %copy-start.1) %v2 = f32[2,3]{1,0:S(1)} parameter(1) %copy-start.2 = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2) @@ -512,7 +512,7 @@ R"(HloModule R4F32OverlapSmall_module %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { %lhs = f32[] parameter(0) %rhs = f32[] parameter(1) - ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE + ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE, type=TOTALORDER } %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { @@ -2399,7 +2399,7 @@ ENTRY c2 { TEST_F(HloParserTest, SimpleAliasing) { const string original = R"( -HloModule Module, input_output_alias={ {0}: (0, {0}), {1}: (0, {1}) } +HloModule Module, input_output_alias={ {0}: (0, {0}, must-alias), {1}: (0, {1}) } ENTRY entry { %p = (f32[], f32[]) parameter(0) @@ -2413,42 +2413,13 @@ ENTRY entry { std::unique_ptr parsed_module = module.ConsumeValueOrDie(); EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {0}), ShapeIndex{0}); + + EXPECT_TRUE( + parsed_module->input_output_alias_config().ParameterMustAlias(0, {0})); EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {1}), ShapeIndex{1}); -} - -TEST_F(HloParserTest, SimpleAliasingShortForm) { - const string original = R"( -HloModule Module, input_output_alias={ {0}: 0, {1}: 1 } - -ENTRY entry { - %p0 = f32[] parameter(0) - %p1 = f32[] parameter(1) - ROOT %out = (f32[], f32[]) tuple(%p0, %p1) -} - )"; - auto module = ParseAndReturnVerifiedModule(original); - TF_ASSERT_OK(module.status()); - std::unique_ptr parsed_module = module.ConsumeValueOrDie(); - EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {}), - ShapeIndex{0}); - EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(1, {}), - ShapeIndex{1}); -} - -TEST_F(HloParserTest, SimpleAliasingShortFormError) { - const string original = R"( -HloModule Module, input_output_alias={ {0}: A, {1}: 1 } - -ENTRY entry { - %p0 = f32[] parameter(0) - %p1 = f32[] parameter(1) - ROOT %out = (f32[], f32[]) tuple(%p0, %p1) -} - )"; - ExpectHasSubstr( - 
ParseAndReturnUnverifiedModule(original).status().error_message(), - "expects integer"); + EXPECT_FALSE( + parsed_module->input_output_alias_config().ParameterMustAlias(0, {1})); } TEST_F(HloParserTest, NestedAliasing) { @@ -2626,6 +2597,21 @@ TEST_F(HloParserTest, ParseSharding) { EXPECT_EQ(sharding.ToString(), original); } +TEST_F(HloParserTest, ParseShardingPartialReplication) { + const string original = "{devices=[2,2]0,1,2,3 last_tile_dim_replicate}"; + TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); + EXPECT_EQ(sharding.ToString(), original); + Array group_tiling({2}); + group_tiling(0) = 0; + group_tiling(1) = 1; + std::vector group0_members({0, 1}); + std::vector group1_members({2, 3}); + EXPECT_EQ( + HloSharding::PartialTile(group_tiling, {group0_members, group1_members}) + .ToString(), + original); +} + TEST_F(HloParserTest, ParseFrontendAttributes) { const string original = R"({attr_a="test_a",attr_b="b",attr_c="s64",attr_d="a/b"})"; diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h index a22a394c6a4..1de231a9a86 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_fix.h +++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h @@ -43,11 +43,12 @@ class HloPassFix : public Pass { while (changed_this_iteration) { TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module)); changed |= changed_this_iteration; - VLOG(3) << "changed_this_iteration: " << changed_this_iteration; + VLOG(3) << Pass::name() << " iteration " << iteration_count + << " changed_this_iteration: " << changed_this_iteration; ++iteration_count; if (iteration_count == kLimit) { - VLOG(1) << "Unexpectedly high number of iterations in HLO passes, " - "exiting fixed point loop."; + VLOG(1) << "Unexpectedly high number of iterations in HLO passes '" + << Pass::name() << "' exiting fixed point loop."; // Return false in case this is fixed point is nested. 
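The HloPassFix hunk here only improves the logging, but the control flow it wraps is worth spelling out. The following is a standalone sketch of the same fixed-point pattern, not the XLA template itself; the iteration limit of 25 is an arbitrary stand-in for the class's kLimit constant.

#include <functional>
#include <iostream>

// Re-run a pass until it stops reporting changes, and bail out (reporting
// "no change") once an iteration limit is hit so that a nested fixed point
// also terminates.
bool RunToFixedPoint(const std::function<bool()>& run_pass_once,
                     int iteration_limit = 25) {
  bool changed = false;
  bool changed_this_iteration = true;
  int iteration_count = 0;
  while (changed_this_iteration) {
    changed_this_iteration = run_pass_once();
    changed |= changed_this_iteration;
    ++iteration_count;
    if (iteration_count == iteration_limit) {
      // Unexpectedly many iterations: give up, and return false in case this
      // fixed point is nested inside another one.
      return false;
    }
  }
  return changed;
}

int main() {
  int budget = 3;  // Pretend the pass makes progress three times, then stops.
  bool changed = RunToFixedPoint([&] { return budget-- > 0; });
  std::cout << "changed: " << changed << "\n";  // prints "changed: 1"
}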
return false; } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index b07ab10827a..3b7b0b61f0a 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -69,6 +69,9 @@ StatusOr HloPassPipeline::RunPassesInternal( } TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo)); changed |= pass_changed; + if (pass_changed) { + VLOG(3) << " Pass caused changes" << pass->name(); + } TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name)); last_pass_name = string(pass_name); if (!pass->IsPassPipeline()) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 83130108dd7..3a5e7ca6f40 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -259,9 +259,15 @@ StatusOr> HloRunner::ExecuteReplicated( return ExecuteReplicated(executable.get(), options, device_assignment); } -StatusOr> HloRunner::ExecuteReplicated( - Executable* executable, const ReplicatedExecuteOptions& options, - DeviceAssignment* device_assignment, ExecutionProfile* profile) { +StatusOr> HloRunner::ExecuteReplicatedImpl( + std::function>( + const std::vector&, + const std::vector>&)> + execution_helper, + std::function argument_count_provider, + std::function argument_provider, + const ReplicatedExecuteOptions& options, + DeviceAssignment* device_assignment) { std::vector> streams; std::vector service_run_options; @@ -269,12 +275,19 @@ StatusOr> HloRunner::ExecuteReplicated( // This reserve() call is necessary for correctness, because // argument_buffer_ptrs contains pointers into the elements of // argument_buffers. - argument_buffers.reserve(options.num_replicas * options.arguments.size()); + const int64 total_argument_count = [&]() { + int64 total = 0; + for (int64 i = 0; i < options.num_replicas; ++i) { + total += argument_count_provider(i); + } + return total; + }(); + argument_buffers.reserve(total_argument_count); // Plus one so we can safely get &argument_buffer_ptrs[0] in case there are // no arguments. - std::vector argument_buffer_ptrs( - options.num_replicas * options.arguments.size() + 1); + std::vector argument_buffer_ptrs(total_argument_count + + 1); std::vector> argument_buffer_slices; int64 index = 0; RunId run_id; @@ -288,7 +301,10 @@ StatusOr> HloRunner::ExecuteReplicated( device, streams.back().get(), device_assignment, run_id)); // Copy arguments to device. 
- for (const Literal* argument : options.arguments) { + const int64 argument_count = argument_count_provider(i); + for (int64 arg_index = 0; arg_index < argument_count; arg_index++) { + const Literal* const argument = argument_provider(i, arg_index); + TF_RET_CHECK(argument != nullptr); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer argument_buffer, backend().transfer_manager()->AllocateScopedShapedBuffer( @@ -299,8 +315,7 @@ StatusOr> HloRunner::ExecuteReplicated( argument_buffer_ptrs[index++] = &argument_buffers.back(); } argument_buffer_slices.emplace_back( - &argument_buffer_ptrs[index - options.arguments.size()], - options.arguments.size()); + &argument_buffer_ptrs[index - argument_count], argument_count); } std::unique_ptr pool; @@ -355,39 +370,9 @@ StatusOr> HloRunner::ExecuteReplicated( } LOG(INFO) << "Replicated execution started"; - std::vector results; - if (!options.use_threads) { - TF_ASSIGN_OR_RETURN(results, - executable->ExecuteOnStreams(service_run_options, - argument_buffer_slices)); - } else { - tensorflow::mutex mutex; - std::vector> thread_results( - options.num_replicas); - { - LOG(INFO) << "Creating thread pool for " << options.num_replicas - << " replicas"; - tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), - "replicas", options.num_replicas); - for (int64 i = 0; i < options.num_replicas; ++i) { - pool.Schedule([&, i] { - auto result = executable->ExecuteOnStream( - &service_run_options[i], argument_buffer_slices[i], nullptr); - tensorflow::mutex_lock lock(mutex); - thread_results[i] = std::move(result); - }); - } - - // Note: the thread pool destructor guarantees it completes all work - // before we leave this scope. - } - for (auto& thread_result : thread_results) { - if (!thread_result.ok()) { - return thread_result.status(); - } - results.push_back(std::move(thread_result).ValueOrDie()); - } - } + TF_ASSIGN_OR_RETURN( + std::vector results, + execution_helper(service_run_options, argument_buffer_slices)); LOG(INFO) << "Replicated execution terminated"; std::vector exec_results; @@ -401,6 +386,104 @@ StatusOr> HloRunner::ExecuteReplicated( return std::move(exec_results); } +StatusOr> HloRunner::ExecuteReplicated( + Executable* executable, const ReplicatedExecuteOptions& options, + DeviceAssignment* device_assignment, ExecutionProfile* profile) { + return ExecuteReplicatedImpl( + [&](const std::vector& service_run_options, + const std::vector>& + argument_buffer_slices) + -> StatusOr> { + std::vector results; + if (!options.use_threads) { + TF_ASSIGN_OR_RETURN( + results, executable->ExecuteOnStreams(service_run_options, + argument_buffer_slices)); + } else { + tensorflow::mutex mutex; + std::vector> thread_results( + options.num_replicas); + { + LOG(INFO) << "Creating thread pool for " << options.num_replicas + << " replicas"; + tensorflow::thread::ThreadPool pool( + tensorflow::Env::Default(), "replicas", options.num_replicas); + for (int64 i = 0; i < options.num_replicas; ++i) { + pool.Schedule([&, i] { + auto result = executable->ExecuteOnStream( + &service_run_options[i], argument_buffer_slices[i], + nullptr); + tensorflow::mutex_lock lock(mutex); + thread_results[i] = std::move(result); + }); + } + + // Note: the thread pool destructor guarantees it completes all work + // before we leave this scope. 
+ } + for (auto& thread_result : thread_results) { + if (!thread_result.ok()) { + return thread_result.status(); + } + results.push_back(std::move(thread_result).ValueOrDie()); + } + } + return results; + }, + [&](int64 replica) { return options.arguments.size(); }, + [&](int64 replica, int64 index) { return options.arguments[index]; }, + options, device_assignment); +} + +StatusOr> HloRunner::ExecuteReplicated( + std::function executable_provider, + std::function argument_count_provider, + std::function argument_provider, + const ReplicatedExecuteOptions& options) { + TF_ASSIGN_OR_RETURN( + DeviceAssignment device_assignment, + backend().computation_placer()->AssignDevices(options.num_replicas, 1)); + return ExecuteReplicatedImpl( + [&](const std::vector& service_run_options, + const std::vector>& + argument_buffer_slices) + -> StatusOr> { + TF_RET_CHECK(options.use_threads); + std::vector results; + tensorflow::mutex mutex; + std::vector> thread_results( + options.num_replicas); + { + LOG(INFO) << "Creating thread pool for " << options.num_replicas + << " replicas"; + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), + "replicas", options.num_replicas); + for (int64 i = 0; i < options.num_replicas; ++i) { + for (const auto& arg : argument_buffer_slices[i]) { + TF_RET_CHECK(arg != nullptr); + } + pool.Schedule([&, i] { + auto result = executable_provider(i)->ExecuteOnStream( + &service_run_options[i], argument_buffer_slices[i], nullptr); + tensorflow::mutex_lock lock(mutex); + thread_results[i] = std::move(result); + }); + } + + // Note: the thread pool destructor guarantees it completes all work + // before we leave this scope. + } + for (auto& thread_result : thread_results) { + if (!thread_result.ok()) { + return thread_result.status(); + } + results.push_back(std::move(thread_result).ValueOrDie()); + } + return results; + }, + argument_count_provider, argument_provider, options, &device_assignment); +} + StatusOr> HloRunner::ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 7e8b301ab54..733bb8bff54 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -176,6 +176,17 @@ class HloRunner { Executable* executable, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr); + // Same as above, but with different reusable Executables. This may update the + // profile information in *executables. + // + // Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes, + // since we've already compiled the Executable. + StatusOr> ExecuteReplicated( + std::function executable_provider, + std::function argument_count_provider, + std::function argument_provider, + const ReplicatedExecuteOptions& options); + // If backend is not created in the constructor, creates and returns the // default backend. If creation fails, crashes the program. // @@ -193,6 +204,17 @@ class HloRunner { int64 device, se::Stream* stream, DeviceAssignment* device_assignment, RunId run_id); + // Common implementation code for ExecuteReplicated() above. 
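The ExecuteReplicatedImpl refactor above threads two callbacks, argument_count_provider and argument_provider, through the common code so that different replicas may receive different argument lists. A standalone sketch of the bookkeeping involved follows, using plain STL types rather than the HloRunner API; it also shows why the buffer vector is reserved up front before any pointers into it are taken.

#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Sum the per-replica argument counts so storage can be reserved exactly once.
int64_t TotalArgumentCount(
    int64_t num_replicas,
    const std::function<int64_t(int64_t)>& argument_count_provider) {
  int64_t total = 0;
  for (int64_t i = 0; i < num_replicas; ++i) {
    total += argument_count_provider(i);
  }
  return total;
}

int main() {
  // Replica 0 gets 2 arguments, replica 1 gets 3.
  std::vector<int64_t> counts = {2, 3};
  auto count_provider = [&](int64_t replica) { return counts[replica]; };

  std::vector<int> argument_buffers;        // stands in for the device buffers
  std::vector<const int*> argument_buffer_ptrs;
  // Reserving up front is required for correctness: argument_buffer_ptrs holds
  // pointers into argument_buffers, so it must never reallocate mid-loop.
  argument_buffers.reserve(TotalArgumentCount(2, count_provider));
  for (int64_t replica = 0; replica < 2; ++replica) {
    for (int64_t arg = 0; arg < count_provider(replica); ++arg) {
      argument_buffers.push_back(static_cast<int>(replica * 10 + arg));
      argument_buffer_ptrs.push_back(&argument_buffers.back());
    }
  }
  std::cout << argument_buffer_ptrs.size() << "\n";  // prints 5
}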
+ StatusOr> ExecuteReplicatedImpl( + std::function>( + const std::vector&, + const std::vector>&)> + execution_helper, + std::function argument_count_provider, + std::function argument_provider, + const ReplicatedExecuteOptions& options, + DeviceAssignment* device_assignment); + std::unique_ptr backend_; }; diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index b0a03707efb..4244cdaceea 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -39,6 +39,47 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { return HloSharding(assignment); } +HloSharding HloSharding::PartialTile( + const Array& group_tile_assignment, + absl::Span> replication_groups) { + auto new_tile_dims = group_tile_assignment.dimensions(); + new_tile_dims.push_back(replication_groups[0].size()); + auto new_tile_assignment = Array(new_tile_dims); + new_tile_assignment.Each([&](absl::Span indices, int64* device) { + std::vector group_index(indices.begin(), indices.end()); + group_index.pop_back(); + int64 group = group_tile_assignment(group_index); + *device = replication_groups[group][indices.back()]; + }); + return PartialTile(new_tile_assignment); +} + +HloSharding HloSharding::PartialTile( + const Array& tile_assignment_last_dim_replicate) { + std::vector> sorted_groups( + tile_assignment_last_dim_replicate.num_elements() / + tile_assignment_last_dim_replicate.dimensions().back()); + auto get_group_id = [&](absl::Span indices) { + int64 group_id = 0; + for (int64 i = 0; i < indices.size() - 1; ++i) { + group_id *= tile_assignment_last_dim_replicate.dim(i); + group_id += indices[i]; + } + return group_id; + }; + tile_assignment_last_dim_replicate.Each( + [&](absl::Span indices, const int64 device) { + sorted_groups[get_group_id(indices)].insert(device); + }); + Array sorted_tile(tile_assignment_last_dim_replicate.dimensions()); + sorted_tile.Each([&](absl::Span indices, int64* device) { + auto begin = sorted_groups[get_group_id(indices)].begin(); + *device = *begin; + sorted_groups[get_group_id(indices)].erase(begin); + }); + return HloSharding(sorted_tile, /*replicate_on_last_tile_dim=*/true); +} + HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { std::vector flattened_list; flattened_list.reserve(sub_shardings.leaf_count()); @@ -101,8 +142,10 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); } - return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]", - StrJoin(tile_assignment_, ","), "}"); + return StrCat( + "{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]", + StrJoin(tile_assignment_, ","), + replicate_on_last_tile_dim_ ? 
" last_tile_dim_replicate}" : "}"); } bool HloSharding::UsesDevice(int64 device) const { @@ -148,6 +191,9 @@ std::vector HloSharding::TileIndexForDevice(int64 device) const { } }); CHECK(!ret_index.empty()); + if (replicate_on_last_tile_dim_) { + ret_index.pop_back(); + } return ret_index; } @@ -157,6 +203,12 @@ int64 HloSharding::DeviceForTileIndex(absl::Span index) const { if (maximal_) { return *tile_assignment_.begin(); } + if (replicate_on_last_tile_dim_ && + index.size() < tile_assignment().num_dimensions()) { + std::vector first_replicated_index(index.begin(), index.end()); + first_replicated_index.push_back(0); + return tile_assignment_(first_replicated_index); + } return tile_assignment_(index); } @@ -167,8 +219,11 @@ std::vector HloSharding::TileOffsetForDevice(const Shape& shape, if (maximal_) { return std::vector(shape.dimensions_size(), 0); } - - CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); + if (replicate_on_last_tile_dim_) { + CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions() - 1); + } else { + CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); + } std::vector index = TileIndexForDevice(device); for (int64 i = 0; i < index.size(); ++i) { const int64 shape_dim = shape.dimensions(i); @@ -187,7 +242,8 @@ std::vector HloSharding::TileLimitForDevice(const Shape& shape, shape.dimensions().end()); } - CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); + CHECK_EQ(shape.dimensions_size() + (ReplicateOnLastTileDim() ? 1 : 0), + tile_assignment_.num_dimensions()); std::vector index = TileIndexForDevice(device); for (int64 i = 0; i < index.size(); ++i) { const int64 shape_dim = shape.dimensions(i); @@ -341,8 +397,10 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return Status::OK(); } - // The tile assignment tensor must have the same rank as the input. - if (shape.rank() != tile_assignment_.num_dimensions()) { + // The tile assignment tensor must have the same rank as the input, or input + // rank + 1 for replicate_on_last_tile_dim_. + if (shape.rank() + (replicate_on_last_tile_dim_ ? 1 : 0) != + tile_assignment_.num_dimensions()) { return tensorflow::errors::InvalidArgument( "Number of tile assignment dimensions is different to the input rank. " "sharding=", @@ -403,7 +461,8 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, proto.tile_assignment_dimensions().end())); std::copy(proto.tile_assignment_devices().begin(), proto.tile_assignment_devices().end(), tile_assignment.begin()); - return HloSharding(tile_assignment); + return proto.replicate_on_last_tile_dim() ? 
PartialTile(tile_assignment) + : HloSharding(tile_assignment); } OpSharding HloSharding::ToProto() const { @@ -429,6 +488,7 @@ OpSharding HloSharding::ToProto() const { result.set_type(OpSharding::MAXIMAL); } else { result.set_type(OpSharding::OTHER); + result.set_replicate_on_last_tile_dim(ReplicateOnLastTileDim()); } return result; } @@ -464,6 +524,17 @@ Shape HloSharding::TileShape(const Shape& shape, int64 device) const { return result_shape; } +int64 HloSharding::NumTiles() const { + if (IsTileMaximal()) { + return 1; + } + if (ReplicateOnLastTileDim()) { + return tile_assignment().num_elements() / + tile_assignment().dimensions().back(); + } + return tile_assignment().num_elements(); +} + HloSharding HloSharding::GetSubSharding(const Shape& shape, const ShapeIndex& index) const { CHECK(IsTuple()); @@ -516,6 +587,9 @@ size_t HloSharding::Hash() const { for (uint32 v : tile_assignment_) { h = tensorflow::Hash64Combine(h, std::hash{}(v)); } + if (replicate_on_last_tile_dim_) { + h = tensorflow::Hash64Combine(h, std::hash{}(1)); + } return h; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 20fa7232e65..e7ba2bc0680 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -54,6 +54,19 @@ class HloSharding { return HloSharding(tile_assignment); } + // Creates a new sharding where data is replicated within each replication + // group, and sharded across replication groups according to + // group_tile_assignment. Replication group members will be sorted. + static HloSharding PartialTile( + const Array& group_tile_assignment, + absl::Span> replication_groups); + + // Creates a partially replicated tiled sharding with device-level tile + // assignment, where the last dimension is the additional replication + // dimension. Replication group members will be sorted. + static HloSharding PartialTile( + const Array& tile_assignment_last_dim_replicate); + // Creates a new sharding which splits a one-dimensional input shape into // `num_tiles` tiles. static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles); @@ -115,6 +128,11 @@ class HloSharding { }); } + // Returns if the sharding has partial replication and partial sharding. If + // true, data is sharded according to other dimensions of tile_assignment(), + // but replicated across devices along the last dimension. + bool ReplicateOnLastTileDim() const { return replicate_on_last_tile_dim_; } + // Returns true if the sharding defines an operation on the given device. bool UsesDevice(int64 device) const; @@ -132,6 +150,10 @@ class HloSharding { // Returns the device that should execute the given tile. // It is an error to call this if is_replicated() is true. + // When ReplicateOnLastTileDim() == true, if index.size() == data rank, it + // returns the first device in that replicated subgroup; otherwise, + // index.size() should be the same as tile_assignment()'s rank and specifies + // the member of the replication subgroup. 
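A standalone sketch of the lookup rule described in the comment above, using plain arrays rather than HloSharding: with devices=[2,2]0,1,2,3 last_tile_dim_replicate, a data-rank index selects the first device of its replication group, while a full-rank index selects a specific group member. The ParseShardingPartialReplication test added earlier exercises exactly this grouping.

#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // devices=[2,2]0,1,2,3 last_tile_dim_replicate: group 0 = {0, 1},
  // group 1 = {2, 3}; the data is split into two tiles along one dimension.
  std::array<std::array<int64_t, 2>, 2> tile_assignment = {{{0, 1}, {2, 3}}};
  auto device_for_tile_index = [&](std::vector<int64_t> index) {
    if (index.size() < 2) index.push_back(0);  // pick the first replica
    return tile_assignment[index[0]][index[1]];
  };
  std::cout << device_for_tile_index({0}) << " "      // 0: first device of group 0
            << device_for_tile_index({1}) << " "      // 2: first device of group 1
            << device_for_tile_index({1, 1}) << "\n";  // 3: explicit group member
}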
// REQUIRES: !IsTuple() int64 DeviceForTileIndex(absl::Span index) const; @@ -188,7 +210,8 @@ class HloSharding { bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && tile_assignment_ == other.tile_assignment_ && - tuple_elements_ == other.tuple_elements_; + tuple_elements_ == other.tuple_elements_ && + replicate_on_last_tile_dim_ == other.replicate_on_last_tile_dim_; } bool operator!=(const HloSharding& other) const { return !(*this == other); } @@ -220,12 +243,17 @@ class HloSharding { // REQUIRES: !IsTuple() Shape TileShape(const Shape& shape, int64 device) const; + // Gets the number of tiles. If it has partial replication, this will not + // equal the device count. + int64 NumTiles() const; + private: HloSharding() : replicated_(true), maximal_(true), tuple_(false), - tile_assignment_({0}) {} + tile_assignment_({0}), + replicate_on_last_tile_dim_(false) {} // device_id values: // -2: magic number to mean unassigned device, used by spatial partitioning // -1: the id of the host @@ -236,18 +264,22 @@ class HloSharding { : replicated_(false), maximal_(true), tuple_(false), - tile_assignment_({1}, device_id) {} - explicit HloSharding(const Array& tile_assignment) + tile_assignment_({1}, device_id), + replicate_on_last_tile_dim_(false) {} + explicit HloSharding(const Array& tile_assignment, + bool replicate_on_last_tile_dim = false) : replicated_(false), maximal_(false), tuple_(false), - tile_assignment_(tile_assignment) {} + tile_assignment_(tile_assignment), + replicate_on_last_tile_dim_(replicate_on_last_tile_dim) {} explicit HloSharding(const std::vector& tuple_shardings) : replicated_(false), maximal_(false), tuple_(true), tile_assignment_({0}), - tuple_elements_(tuple_shardings) {} + tuple_elements_(tuple_shardings), + replicate_on_last_tile_dim_(false) {} // Checks that the number of elements in tuple_elements_ is consistent with // the tuple shape passes as argument. @@ -283,6 +315,11 @@ class HloSharding { // present for the root. This is a flattened list of all the leaf shardings in // a tuple shape, by pre-order walk (ShapeTree iterator order). std::vector tuple_elements_; + // This flag is to support partial replication and partial sharding. If it is + // true, tile_assignment_ will have an extra dimension in addition to the data + // shape rank, and the added last dimension represents the subgroups of + // replications, i.e., elements in slice [..., :] will be replicated. + bool replicate_on_last_tile_dim_; }; std::ostream& operator<<(std::ostream& out, const HloSharding& sharding); diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc index 94c348cdeaa..da4e3d61a81 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_util.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" #include +#include #include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/array.h" @@ -105,21 +106,28 @@ HloSharding TransposeSharding(const HloSharding& sharding, if (sharding.IsTileMaximal()) { return sharding; } - const int64 rank = dimensions.size(); + auto perm_dimensions = dimensions; + if (sharding.ReplicateOnLastTileDim() && + dimensions.size() < sharding.tile_assignment().num_dimensions()) { + perm_dimensions.push_back(dimensions.size()); + } + const int64 rank = perm_dimensions.size(); std::vector tile_assignment_dim(rank); for (int64 i = 0; i < rank; ++i) { - tile_assignment_dim[i] = sharding.tile_assignment().dim(dimensions[i]); + tile_assignment_dim[i] = sharding.tile_assignment().dim(perm_dimensions[i]); } Array tile_assignment = sharding.tile_assignment(); tile_assignment.Reshape(tile_assignment_dim); tile_assignment.Each([&](absl::Span indices, int64* value) { std::vector src_indices(indices.size(), -1); for (int64 i = 0; i < indices.size(); ++i) { - src_indices[dimensions[i]] = indices[i]; + src_indices[perm_dimensions[i]] = indices[i]; } *value = sharding.tile_assignment()(src_indices); }); - return HloSharding::Tile(tile_assignment); + return sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(tile_assignment) + : HloSharding::Tile(tile_assignment); } absl::optional ReshapeSharding(const Shape& source_shape, @@ -226,8 +234,14 @@ absl::optional ReshapeSharding(const Shape& source_shape, } } Array new_tile_assignment = sharding.tile_assignment(); + if (sharding.ReplicateOnLastTileDim()) { + target_tile_assignment_dimensions.push_back( + sharding.tile_assignment().dimensions().back()); + } new_tile_assignment.Reshape(target_tile_assignment_dimensions); - return HloSharding::Tile(new_tile_assignment); + return sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); } HloSharding ReverseSharding(const HloSharding& sharding, @@ -245,7 +259,9 @@ HloSharding ReverseSharding(const HloSharding& sharding, } *device = sharding.tile_assignment()(original_indices); }); - return HloSharding::Tile(new_tile_assignment); + return sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); } HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim, @@ -331,17 +347,26 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding, index_dim++; } } + + if (index_sharding.ReplicateOnLastTileDim()) { + output_tile_assignment_dims.push_back( + index_sharding.tile_assignment().dimensions().back()); + } + Array new_tile_assignment = index_sharding.tile_assignment(); if (new_tile_assignment.num_elements() != Product(output_tile_assignment_dims)) { return HloSharding::Replicate(); } new_tile_assignment.Reshape(output_tile_assignment_dims); - return HloSharding::Tile(new_tile_assignment); + return index_sharding.ReplicateOnLastTileDim() + ? 
HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); } HloSharding GatherIndexSharding(const HloSharding& output_sharding, const HloInstruction* hlo) { + CHECK(hlo->opcode() == HloOpcode::kGather); if (output_sharding.IsTileMaximal()) { return output_sharding; } @@ -354,13 +379,28 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding, output_sharding.tile_assignment().dim(i)); } } + int64 index_rank = hlo->operand(1)->shape().rank(); + + // Vector indices sharding is not supported yet. + if (index_rank > index_tile_assignment_dims.size()) { + index_tile_assignment_dims.insert( + index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1); + } + + if (output_sharding.ReplicateOnLastTileDim()) { + index_tile_assignment_dims.push_back( + output_sharding.tile_assignment().dimensions().back()); + } + Array new_tile_assignment = output_sharding.tile_assignment(); if (new_tile_assignment.num_elements() != Product(index_tile_assignment_dims)) { return HloSharding::Replicate(); } new_tile_assignment.Reshape(index_tile_assignment_dims); - return HloSharding::Tile(new_tile_assignment); + return output_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); } HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) { @@ -430,13 +470,19 @@ HloSharding ScatterIndexSharding(const HloSharding& data_sharding, if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) { index_tile_assignment_dims.push_back(1); } + if (data_sharding.ReplicateOnLastTileDim()) { + index_tile_assignment_dims.push_back( + data_sharding.tile_assignment().dimensions().back()); + } Array new_tile_assignment = data_sharding.tile_assignment(); if (new_tile_assignment.num_elements() != Product(index_tile_assignment_dims)) { return HloSharding::Replicate(); } new_tile_assignment.Reshape(index_tile_assignment_dims); - return HloSharding::Tile(new_tile_assignment); + return data_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); } HloSharding ScatterDataSharding(const HloSharding& index_sharding, @@ -456,13 +502,19 @@ HloSharding ScatterDataSharding(const HloSharding& index_sharding, index_dim++; } } + if (index_sharding.ReplicateOnLastTileDim()) { + data_tile_assignment_dims.push_back( + index_sharding.tile_assignment().dimensions().back()); + } Array new_tile_assignment = index_sharding.tile_assignment(); if (new_tile_assignment.num_elements() != Product(data_tile_assignment_dims)) { return HloSharding::Replicate(); } new_tile_assignment.Reshape(data_tile_assignment_dims); - return HloSharding::Tile(new_tile_assignment); + return index_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); } HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding, @@ -589,9 +641,15 @@ absl::optional PassthroughOperandToGatherOutputOrScatterUpdate( } passthrough_tile[offset_dim] = dim_partitions; } + if (operand_sharding.ReplicateOnLastTileDim()) { + passthrough_tile.push_back( + operand_sharding.tile_assignment().dimensions().back()); + } Array tile_assignment = operand_sharding.tile_assignment(); tile_assignment.Reshape(passthrough_tile); - return HloSharding::Tile(tile_assignment); + return operand_sharding.ReplicateOnLastTileDim() + ? 
HloSharding::PartialTile(tile_assignment) + : HloSharding::Tile(tile_assignment); } // Inverse of PassthroughOperandToGatherOutputOrScatterUpdate. @@ -625,12 +683,19 @@ absl::optional PassthroughGatherOutputOrScatterUpdateToOperand( } passthrough_tile[i] = dim_partitions; } + + if (update_or_gather_sharding.ReplicateOnLastTileDim()) { + passthrough_tile.push_back( + update_or_gather_sharding.tile_assignment().dimensions().back()); + } Array tile_assignment = update_or_gather_sharding.tile_assignment(); if (tile_assignment.num_elements() != Product(passthrough_tile)) { return absl::nullopt; } tile_assignment.Reshape(passthrough_tile); - return HloSharding::Tile(tile_assignment); + return update_or_gather_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(tile_assignment) + : HloSharding::Tile(tile_assignment); } } // namespace @@ -777,5 +842,119 @@ std::vector DevicesForSharding( return devices; } +HloSharding PartiallyReplicateTiledShardingOnDims( + const HloSharding& sharding, const std::vector& dims_to_replicate) { + if (sharding.IsTileMaximal()) { + return sharding; + } + int64 group_count = 1; + for (int64 dim : dims_to_replicate) { + if (sharding.ReplicateOnLastTileDim()) { + CHECK_LT(dim, sharding.tile_assignment().num_dimensions()); + } + group_count *= sharding.tile_assignment().dim(dim); + } + if (group_count == 1) { + return sharding; + } + if (group_count == sharding.NumTiles()) { + return HloSharding::Replicate(); + } + std::vector dim_permutation( + sharding.tile_assignment().num_dimensions()); + std::iota(dim_permutation.begin(), dim_permutation.end(), 0); + absl::c_sort(dim_permutation, [&](const int64 a, const int64 b) { + return absl::c_linear_search(dims_to_replicate, a) < + absl::c_linear_search(dims_to_replicate, b); + }); + auto transposed = TransposeSharding(sharding, dim_permutation); + auto new_tile = transposed.tile_assignment(); + std::vector new_tile_shape( + sharding.tile_assignment().dimensions().begin(), + sharding.tile_assignment().dimensions().end()); + for (int64 dim : dims_to_replicate) { + new_tile_shape[dim] = 1; + } + if (sharding.ReplicateOnLastTileDim()) { + new_tile_shape.back() *= group_count; + } else { + new_tile_shape.push_back(group_count); + } + new_tile.Reshape(new_tile_shape); + return HloSharding::PartialTile(new_tile); +} + +HloSharding RemoveShapeDimensions(const HloSharding& sharding, + const std::vector& dims_to_remove) { + if (sharding.IsTileMaximal() || dims_to_remove.empty()) { + return sharding; + } + std::vector new_tile_shape; + new_tile_shape.reserve(sharding.tile_assignment().num_dimensions() - + dims_to_remove.size()); + for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + if (absl::c_linear_search(dims_to_remove, i)) { + CHECK_EQ(sharding.tile_assignment().dim(i), 1); + } else { + new_tile_shape.push_back(sharding.tile_assignment().dim(i)); + } + } + auto new_tile = sharding.tile_assignment(); + new_tile.Reshape(new_tile_shape); + return sharding.ReplicateOnLastTileDim() ? 
HloSharding::PartialTile(new_tile) + : HloSharding::Tile(new_tile); +} + +absl::optional TransposeShardingWithCollapsedDims( + const HloSharding& source, absl::Span src_to_tgt, + absl::Span tgt_to_src) { + if (source.IsTileMaximal()) { + return source; + } + if (source.ReplicateOnLastTileDim() && + src_to_tgt.size() < source.tile_assignment().num_dimensions()) { + std::vector new_src_to_tgt(src_to_tgt.begin(), src_to_tgt.end()); + new_src_to_tgt.push_back(tgt_to_src.size()); + std::vector new_tgt_to_src(tgt_to_src.begin(), tgt_to_src.end()); + new_tgt_to_src.push_back(src_to_tgt.size()); + return TransposeShardingWithCollapsedDims(source, new_src_to_tgt, + new_tgt_to_src); + } + std::vector tgt_dims_skipping_new(tgt_to_src.size(), -1); + int64 skipped_tgt_dims = 0; + for (int64 i = 0; i < tgt_to_src.size(); ++i) { + if (tgt_to_src[i] < 0) { + skipped_tgt_dims++; + } else { + tgt_dims_skipping_new[i] = i - skipped_tgt_dims; + } + } + int64 skipped_src_dims = absl::c_count(src_to_tgt, -1); + std::vector perm(src_to_tgt.size()); + for (int64 i = 0; i < src_to_tgt.size(); ++i) { + if (src_to_tgt[i] < 0) { + if (source.tile_assignment().dim(i) > 1) { + return absl::nullopt; + } + perm[src_to_tgt.size() - skipped_src_dims] = i; + skipped_src_dims--; + } else { + perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i; + } + } + auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm); + auto reshape_tiles = tgt_sharding.tile_assignment(); + std::vector tgt_tiles(tgt_to_src.size(), 1); + for (int64 i = 0; i < tgt_tiles.size(); ++i) { + if (tgt_to_src[i] >= 0) { + tgt_tiles[i] = reshape_tiles.dim(tgt_dims_skipping_new[i]); + } + } + reshape_tiles.Reshape(tgt_tiles); + return source.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(reshape_tiles) + : HloSharding::Tile(reshape_tiles); +} + } // namespace hlo_sharding_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.h b/tensorflow/compiler/xla/service/hlo_sharding_util.h index cc4068121ae..0de01fcab7e 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_util.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.h @@ -163,6 +163,24 @@ IdentityValueAndHloOpcodeForScatterReduceComputation( std::vector DevicesForSharding( const HloSharding& sharding, const std::vector& available_devices); +// Returns a sharding that replicates data across devices along the given +// dimensions in the original sharding. +HloSharding PartiallyReplicateTiledShardingOnDims( + const HloSharding& sharding, const std::vector& dims_to_replicate); + +// Returns a sharding the removes given tile dimensions. +// +// Precondition: if not tile maximal, the size of each tile dimension must be 1. +HloSharding RemoveShapeDimensions(const HloSharding& sharding, + const std::vector& dims_to_remove); + +// Similar to TransposeSharding(), but allows removing/adding non-partitioned +// dimensions. In src_to_tgt and tgt_to_src, -1 represents a non-existing +// dimension. 
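A standalone sketch of the bookkeeping behind PartiallyReplicateTiledShardingOnDims, declared in this header hunk: the tile counts of the replicated dimensions are multiplied into a group count, and the edge cases (a group count of one, or a group count covering every tile) short-circuit to the unchanged or fully replicated sharding. The numbers below are arbitrary, and plain integers stand in for the HloSharding types.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> tile_dims = {2, 4};   // 8 tiles over a rank-2 shape
  std::vector<int64_t> dims_to_replicate = {1};
  int64_t group_count = 1;
  for (int64_t dim : dims_to_replicate) group_count *= tile_dims[dim];
  int64_t num_tiles = 1;
  for (int64_t d : tile_dims) num_tiles *= d;
  if (group_count == 1) {
    std::cout << "sharding unchanged\n";
  } else if (group_count == num_tiles) {
    std::cout << "fully replicated\n";
  } else {
    std::cout << "partially replicated: " << num_tiles / group_count
              << " tiles, each replicated " << group_count << " ways\n";
  }
}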
+absl::optional TransposeShardingWithCollapsedDims( + const HloSharding& source, absl::Span src_to_tgt, + absl::Span tgt_to_src); + } // namespace hlo_sharding_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index d395fddcc5d..0346e9077a0 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -703,6 +703,20 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } +Status ShapeVerifier::HandleDynamicReshape(HloInstruction* dynamic_reshape) { + // Check for mixed precision. + const Shape& operand_shape = dynamic_reshape->operand(0)->shape(); + TF_RET_CHECK(SameElementType(dynamic_reshape->shape(), operand_shape)); + TF_RET_CHECK(ShapeUtil::ElementsIn(dynamic_reshape->shape()) == + ShapeUtil::ElementsIn(operand_shape)); + TF_RET_CHECK(dynamic_reshape->shape().rank() + 1 == + dynamic_reshape->operand_count()); + for (int64 i = 1; i < dynamic_reshape->operand_count(); ++i) { + TF_RET_CHECK(dynamic_reshape->operand(i)->shape().element_type() == S32); + } + return Status::OK(); +} + Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { // Check for mixed precision. const Shape& operand_shape = reshape->operand(0)->shape(); @@ -1023,7 +1037,7 @@ namespace { // inputs. Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { switch (instruction->opcode()) { - // White list the following opcodes for mixed-precision check, because + // Allow-list the following opcodes for mixed-precision check, because // they involve data pass through or grouping via tuples, where the // precisions of buffers can be different. case HloOpcode::kCall: diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 85b02e0518c..03fca5938ff 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -78,6 +78,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleBitcast(HloInstruction* bitcast) override; Status HandleBroadcast(HloInstruction* broadcast) override; Status HandleReshape(HloInstruction* reshape) override; + Status HandleDynamicReshape(HloInstruction* dynamic_reshape) override; Status HandleTranspose(HloInstruction* transpose) override; Status HandleParameter(HloInstruction*) override; Status HandleFusion(HloInstruction*) override; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 8d8930615b2..11472f55792 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -102,6 +102,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kReducePrecision: case HloOpcode::kReplicaId: case HloOpcode::kReshape: + case HloOpcode::kDynamicReshape: case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: @@ -515,11 +516,12 @@ StatusOr InstructionFusion::Run(HloModule* module) { continue; } - VLOG(5) << "Considering fusion of: " << instruction->ToString(); std::vector& sorted_operand_numbers = next_entry.second; for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); + VLOG(5) << "Considering fusion of: " << instruction->ToString() + << " with operand " << operand->name(); if (!operand->IsFusible()) { VLOG(3) << "Operand (" << operand->ToString() << ") is not 
fusible"; @@ -600,6 +602,9 @@ StatusOr InstructionFusion::Run(HloModule* module) { VLOG(1) << FusionConfigToString(*fusion_config); module->set_config(module_config); } + + reachability_.reset(); + VLOG(1) << "Fusion count: " << fuse_count; return changed; @@ -709,4 +714,23 @@ HloInstruction::FusionKind InstructionFusion::ChooseKind( return HloInstruction::FusionKind::kLoop; } +bool InstructionFusion::ReusesOperandElements(const HloInstruction* consumer, + int64 operand_index) { + auto operand = consumer->operand(operand_index); + auto it = reused_fusion_operands_.find(consumer); + if (it != reused_fusion_operands_.end() && it->second.contains(operand)) { + return true; + } + bool reuses = consumer->ReusesOperandElements(operand_index); + // If a parameter was reused, we can cache this information. Fusion + // computations only ever grow, so it becomes more likely that a parameter is + // reused, but a reused parameter will never become *not* reused. + if (reuses) { + // We cache the operand corresponding to the fusion parameter, because the + // parameter pointers would be invalidated after the next fusion. + reused_fusion_operands_[consumer].insert(operand); + } + return reuses; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 90d9da48e33..d51bf700371 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -1,4 +1,3 @@ -#include "absl/container/flat_hash_map.h" /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +19,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -138,6 +139,11 @@ class InstructionFusion : public HloModulePass { return config_collection_mode_; } + // Returns whether 'consumer' may reuse elements of its `operand_index`th + // operand. + bool ReusesOperandElements(const HloInstruction* consumer, + int64 operand_index); + private: // The set of producers whose consumers we cannot fuse into. using HloInstructionSet = std::unordered_set; @@ -172,6 +178,11 @@ class InstructionFusion : public HloModulePass { // Configuration mode. FusionConfigCollection config_collection_mode_; + // Caches which operands are reused inside fusion computations. 
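The cache declared above relies on a monotonicity argument spelled out in instruction_fusion.cc: fusion computations only ever grow, so an operand that is reused stays reused, and only positive answers need to be memoized. Below is a standalone sketch of that one-way cache, with an invented expensive_check callback standing in for HloInstruction::ReusesOperandElements and plain ints standing in for instruction pointers. The XLA version keys the cache on the operand rather than the fusion parameter because parameter pointers are invalidated by the next fusion, as the comment in the hunk notes.

#include <functional>
#include <iostream>
#include <unordered_map>
#include <unordered_set>

// Returns whether `consumer` reuses `operand`, caching only "true" answers:
// the property can flip from false to true as fusions grow, but never back.
bool ReusesOperand(int consumer, int operand,
                   const std::function<bool(int, int)>& expensive_check,
                   std::unordered_map<int, std::unordered_set<int>>* cache) {
  auto it = cache->find(consumer);
  if (it != cache->end() && it->second.count(operand)) return true;
  bool reuses = expensive_check(consumer, operand);
  if (reuses) (*cache)[consumer].insert(operand);  // only cache "true"
  return reuses;
}

int main() {
  int calls = 0;
  auto check = [&](int, int) { ++calls; return true; };
  std::unordered_map<int, std::unordered_set<int>> cache;
  ReusesOperand(/*consumer=*/1, /*operand=*/7, check, &cache);
  ReusesOperand(1, 7, check, &cache);  // served from the cache
  std::cout << "expensive checks: " << calls << "\n";  // prints 1
}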
+ absl::flat_hash_map> + reused_fusion_operands_; + TF_DISALLOW_COPY_AND_ASSIGN(InstructionFusion); }; diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 7a4eefc1ab6..3444d4cae42 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -34,6 +34,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:cholesky_expander", + "//tensorflow/compiler/xla/service:comparison_expander", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:custom_call_target_registry", diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 1649be2ca8f..a059482d832 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/cholesky_expander.h" +#include "tensorflow/compiler/xla/service/comparison_expander.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" @@ -81,6 +82,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index cc7fdeaf0f6..1446b55f5a8 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -52,6 +52,7 @@ InterpreterExecutable::InterpreterExecutable( } StatusOr InterpreterExecutable::Evaluate( + const ServiceExecutableRunOptions* run_options, const HloComputation& computation, absl::Span arg_literals) { // Execute the graph using the HloEvaluator. tensorflow::mutex_lock lock(evaluator_lock_); diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index ce68a8472f5..514ed029a22 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -51,7 +51,8 @@ class InterpreterExecutable : public InterpreterExecutableBase { static int64 ShapeSizeBytes(const Shape& shape); protected: - StatusOr Evaluate(const HloComputation& computation, + StatusOr Evaluate(const ServiceExecutableRunOptions* run_options, + const HloComputation& computation, absl::Span arg_literals) override TF_LOCKS_EXCLUDED(evaluator_lock_); diff --git a/tensorflow/compiler/xla/service/interpreter/executable_base.cc b/tensorflow/compiler/xla/service/interpreter/executable_base.cc index 4b6a8aa5202..745750bffe1 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable_base.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable_base.cc @@ -50,11 +50,15 @@ StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( // TransferManager methods below. 
std::vector argument_buffers; argument_buffers.reserve(arguments.size()); + int device_ordinal = run_options->device_ordinal(); + if (device_ordinal < 0) { + device_ordinal = 0; + } for (auto& argument : arguments) { const ShapeTree& buffers = argument.Buffers(); argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(), /*platform=*/nullptr, - /*device_ordinal=*/0)); + /*device_ordinal=*/device_ordinal)); auto in_it = buffers.begin(); auto out_it = argument_buffers.back().buffers().begin(); for (; in_it != buffers.end(); ++in_it, ++out_it) { @@ -118,7 +122,7 @@ StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( } TF_ASSIGN_OR_RETURN(Literal result_literal, - Evaluate(*computation, arg_literals)); + Evaluate(run_options, *computation, arg_literals)); // Shrink the generated dynamic shape into static shape. result_literal = result_literal.ToStatic(); diff --git a/tensorflow/compiler/xla/service/interpreter/executable_base.h b/tensorflow/compiler/xla/service/interpreter/executable_base.h index a02ab7af8d0..eb47841a179 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable_base.h +++ b/tensorflow/compiler/xla/service/interpreter/executable_base.h @@ -44,6 +44,7 @@ class InterpreterExecutableBase : public Executable { protected: virtual StatusOr Evaluate( + const ServiceExecutableRunOptions* run_options, const HloComputation& computation, absl::Span arg_literals) = 0; diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index 9e4bdeb2b2d..9416b11a07e 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -38,7 +38,6 @@ limitations under the License. #include "tensorflow/stream_executor/launch_dim.h" #include "tensorflow/stream_executor/plugin.h" #include "tensorflow/stream_executor/rng.h" -#include "tensorflow/stream_executor/shared_memory_config.h" #include "tensorflow/stream_executor/stream.h" #include "tensorflow/stream_executor/stream_executor.h" #include "tensorflow/stream_executor/stream_executor_internal.h" @@ -182,15 +181,6 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { return true; } - SharedMemoryConfig GetDeviceSharedMemoryConfig() override { - return SharedMemoryConfig::kDefault; - } - - port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config) override { - return port::Status{port::error::UNIMPLEMENTED, - "Shared memory not supported"}; - } - std::unique_ptr CreateEventImplementation() override { return nullptr; diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index bea0f1fb93c..55569cfde0e 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -1891,7 +1891,7 @@ Status LayoutAssignment::RunOnComputation( ? 
ShapeUtil::GetSubshape(instruction->literal().shape(), buffer.index()) .layout() - : LayoutUtil::GetDefaultLayoutForShape(buffer.shape()); + : GetUnconstrainedLayout(buffer); TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer, /*mandatory=*/false)); @@ -2278,6 +2278,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kReduce: case HloOpcode::kReplicaId: case HloOpcode::kReshape: + case HloOpcode::kDynamicReshape: case HloOpcode::kRng: case HloOpcode::kRngBitGenerator: case HloOpcode::kRngGetAndUpdateState: diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index a04d056c618..def620bcee9 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -338,6 +339,9 @@ class LayoutAssignment : public HloModulePass { const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints); + virtual Layout GetUnconstrainedLayout(const LogicalBuffer& buffer) { + return LayoutUtil::GetDefaultLayoutForShape(buffer.shape()); + } // Called after layouts of an instruction have been finalized to allow // subclasses to check for platform specific assumptions. virtual Status Verify(const HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 0371ce71874..6aa33a10d64 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -244,16 +244,7 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient( } else { total = 0; for (const auto* user : indexing_users[instruction]) { - int64 weight = 1; - // Concatenate is special: the index differs for each operand, so - // in the worst case we have to deal with as many index values as - // the number of operands of Concatenate. By considering the worst - // case, we are more conservative than necessary regarding - // refusing to fuse. - if (user->opcode() == HloOpcode::kConcatenate) { - weight = user->operand_count(); - } - total += index_usage_count[user] * weight; + total += index_usage_count[user]; } } for (const auto* operand : instruction->operands()) { @@ -298,15 +289,9 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient( evaluate_fusion_computation(producer); } - // Sum up the total number of emitted ops. - int64 total = 0; - for (const auto& entry : index_usage_count) { - total += entry.second; - } - // Check that the code duplication has at most a factor of 15 (where 15 is an // arbitrary constant that seems to work). 
- return total > 15 * index_usage_count.size(); + return index_usage_count[producer] > 15; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index b01ae2efe43..2963d546380 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -415,9 +415,10 @@ llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper, return inst; } -string IrName(string a) { - a.erase(std::remove(a.begin(), a.end(), '%'), a.end()); - return a; +string IrName(absl::string_view a) { + std::string s(a); + s.erase(std::remove(s.begin(), s.end(), '%'), s.end()); + return s; } string IrName(absl::string_view a, absl::string_view b) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 642965b6470..c0a55e4da33 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -87,7 +87,7 @@ string DumpModuleToString(const llvm::Module& module); // - joining all of the nonempty inputs by '.', and then // - removing all '%'s. // -string IrName(string a); +string IrName(absl::string_view a); string IrName(absl::string_view a, absl::string_view b); string IrName(const HloInstruction* a, absl::string_view b = ""); diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index daf98478194..d89a9c2e0a5 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -62,10 +62,11 @@ void EmitTuple(const IrArray& tuple, absl::Span operands, llvm::IRBuilder<>* b) { llvm::Module* module = getModuleFromBuilder(b); for (size_t i = 0; i < operands.size(); ++i) { + auto* cast = + b->CreatePointerCast(operands[i], PrimitiveTypeToIrType(TUPLE, module)); auto* store = b->CreateStore( - b->CreatePointerCast(operands[i], PrimitiveTypeToIrType(TUPLE, module)), - b->CreateInBoundsGEP(tuple.GetBasePointer(), - {b->getInt64(0), b->getInt64(i)})); + cast, b->CreateInBoundsGEP(tuple.GetBasePointer(), + {b->getInt64(0), b->getInt64(i)})); tuple.AnnotateLoadStoreInstructionWithMetadata(store); } } diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 0f7daa67800..2963fde9036 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -80,7 +80,7 @@ float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( } float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, MemorySpaceAssignmentCostAnalysis::Cache* cache) const { const HloInstruction& defining_instruction = *interval.buffer->defining_instruction(); @@ -119,14 +119,10 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness( } } - // Get performance slowdown in seconds of prefetching current BufferInterval - // causing to other BufferIntervals. - float alternate_mem_slowdown = - GetInstructionElapsedDueToMemorySlowdown(interval.size); - - // Divide by the size of the buffer to prioritize smaller buffers that will - // give the largest alternate memory benefit. 
- return (alternate_mem_benefit - alternate_mem_slowdown) / interval.size; + // Penalize larger buffers by dividing the benefit by the square root of the + // size. Empirically, we observed this resulted in better performance compared + // to dividing by the size. + return alternate_mem_benefit / std::sqrt(interval.size); } int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel( @@ -236,15 +232,26 @@ int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime( } int64 InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime( - const HloUse& use, int64 start_time, int64 end_time) const { + const Shape& shape, int64 start_time, int64 end_time, + const HloUse* use) const { return end_time - min_overlap_count_; } +int64 InstructionCountPrefetchIntervalPicker::PreferredPrefetchStartTime( + const Shape& shape, int64 earliest_prefetch_start_time, + int64 latest_prefetch_start_time, int64 prefetch_end_time) const { + return std::max(earliest_prefetch_start_time, + prefetch_end_time - max_overlap_count_); +} + void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use, int64 start_time, int64 end_time) { end_time_ = end_time; - current_prefetch_time_ = std::max(start_time, end_time_ - max_overlap_count_); + const Shape& shape = ShapeUtil::GetSubshape( + use.instruction->operand(use.operand_number)->shape(), use.operand_index); + current_prefetch_time_ = + PreferredPrefetchStartTime(shape, start_time, end_time, end_time); } int64 InstructionCountPrefetchIntervalPicker::Next() { @@ -361,18 +368,22 @@ int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime( } int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime( - const HloUse& use, int64 start_time, int64 end_time) const { - const Shape& shape = ShapeUtil::GetSubshape( - use.instruction->operand(use.operand_number)->shape(), use.operand_index); + const Shape& shape, int64 start_time, int64 end_time, + const HloUse* use) const { // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_. float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape); - // Estimate the time we would save by having this op in alternate memory. - float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction); - float elapsed_time_in_alternate_mem = - cost_analysis_.GetInstructionElapsedInAlternateMemory( - *use.instruction, use.operand_number, - /*output_in_alternate_mem=*/false); - float inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem; + // If there is a use, estimate the time we would save by having this op in + // alternate memory. + float inst_elapsed_reduction = 0.0f; + if (use) { + float elapsed_time = + cost_analysis_.GetInstructionElapsed(*use->instruction); + float elapsed_time_in_alternate_mem = + cost_analysis_.GetInstructionElapsedInAlternateMemory( + *use->instruction, use->operand_number, + /*output_in_alternate_mem=*/false); + inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem; + } int end_nest_level = while_nest_level_[end_time]; // Find the latest time we're allowed to start prefetching. 
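The hunk above replaces the slowdown-adjusted, size-normalized benefit with alternate_mem_benefit / sqrt(size). A standalone sketch of that ranking rule follows; the types and values are illustrative stand-ins, not the patch's MemorySpaceAssignmentCostAnalysis API.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-in for a buffer interval: only the fields the ranking needs.
struct BufferStats {
  const char* name;
  float alternate_mem_benefit;  // Estimated seconds saved in alternate memory.
  int64_t size_bytes;
};

// Old rule: (benefit - slowdown) / size. New rule in the patch: benefit / sqrt(size),
// which still favors small buffers but penalizes large ones less aggressively.
float MemoryBoundedness(const BufferStats& b) {
  return b.alternate_mem_benefit / std::sqrt(static_cast<float>(b.size_bytes));
}

int main() {
  std::vector<BufferStats> buffers = {
      {"small_hot", 1.0e-4f, 4 * 1024},
      {"large_warm", 4.0e-4f, 1024 * 1024},
  };
  // Higher memory boundedness means higher priority for alternate memory.
  std::sort(buffers.begin(), buffers.end(),
            [](const BufferStats& x, const BufferStats& y) {
              return MemoryBoundedness(x) > MemoryBoundedness(y);
            });
  for (const auto& b : buffers) {
    std::cout << b.name << " boundedness=" << MemoryBoundedness(b) << "\n";
  }
  return 0;
}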
@@ -390,6 +401,33 @@ int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime( return latest_prefetch_time; } +int64 CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime( + const Shape& shape, int64 earliest_prefetch_start_time, + int64 latest_prefetch_start_time, int64 prefetch_end_time) const { + // Between the earliest and latest prefetch interval, find the interval + // closest to the preferred interval and start iterating from there. + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape); + int64 preferred_prefetch_start_time = earliest_prefetch_start_time; + float preferred_interval = + preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed; + float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time, + prefetch_end_time); + int end_nest_level = while_nest_level_[prefetch_end_time]; + for (int64 prefetch_start_time = earliest_prefetch_start_time + 1; + prefetch_start_time <= latest_prefetch_start_time; + ++prefetch_start_time) { + float interval = + GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time); + if (while_nest_level_[prefetch_start_time] == end_nest_level && + std::abs(preferred_interval - interval) < + std::abs(preferred_interval - best_interval)) { + best_interval = interval; + preferred_prefetch_start_time = prefetch_start_time; + } + } + return preferred_prefetch_start_time; +} + int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime( int64 original_prefetch_end_time, int64 proposed_prefetch_end_time) const { // Iterate towards the beginning until we find a suitable end time that is the @@ -422,7 +460,8 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, // Find the latest time we're allowed to start prefetching. float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_; - latest_prefetch_time_ = LatestPrefetchStartTime(use, start_time, end_time); + latest_prefetch_time_ = + LatestPrefetchStartTime(shape, start_time, end_time, &use); // Find the earliest time we're allowed to start prefetching. float max_interval = max_async_copy_to_overlap_ratio_ * @@ -443,24 +482,10 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, return; } - // Between the earliest and latest prefetch interval, find the interval - // closest to the preferred interval and start iterating from there. 
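The new PreferredPrefetchStartTime above walks the candidate start times and keeps the one whose logical interval to the prefetch end is closest to preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed, skipping times at a different while-nest level. A self-contained sketch of that selection, with the cost model reduced to a plain cumulative-elapsed array (all names here are hypothetical):

#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// elapsed[t] = cumulative logical elapsed time up to instruction t (monotone).
float IntervalElapsed(const std::vector<float>& elapsed, int64_t start, int64_t end) {
  return elapsed[end] - elapsed[start];
}

// Pick the start time in [earliest, latest] whose interval to `end` is closest
// to the preferred interval, restricted to times at the same nest level as `end`.
int64_t PreferredPrefetchStartTime(const std::vector<float>& elapsed,
                                   const std::vector<int>& nest_level,
                                   int64_t earliest, int64_t latest, int64_t end,
                                   float preferred_interval) {
  int64_t best_time = earliest;
  float best_interval = IntervalElapsed(elapsed, earliest, end);
  for (int64_t t = earliest + 1; t <= latest; ++t) {
    float interval = IntervalElapsed(elapsed, t, end);
    if (nest_level[t] == nest_level[end] &&
        std::abs(preferred_interval - interval) <
            std::abs(preferred_interval - best_interval)) {
      best_interval = interval;
      best_time = t;
    }
  }
  return best_time;
}

int main() {
  std::vector<float> elapsed = {0, 1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<int> nest_level = {0, 0, 1, 1, 0, 0, 0, 0, 0};
  // Prefer to overlap ~3 units of compute with an async copy ending at t=8.
  std::cout << PreferredPrefetchStartTime(elapsed, nest_level,
                                          /*earliest=*/1, /*latest=*/7, /*end=*/8,
                                          /*preferred_interval=*/3.0f)
            << "\n";  // Prints 5: elapsed(5,8)=3 and t=5 is at nest level 0.
  return 0;
}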
- int64 starting_prefetch_time = earliest_prefetch_time_; + int64 starting_prefetch_time = PreferredPrefetchStartTime( + shape, earliest_prefetch_time_, latest_prefetch_time_, end_logical_time_); float preferred_interval = preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed_; - float best_interval = - GetLogicalIntervalElapsed(earliest_prefetch_time_, end_logical_time_); - for (int64 prefetch_time = earliest_prefetch_time_ + 1; - prefetch_time <= latest_prefetch_time_; ++prefetch_time) { - float interval = - GetLogicalIntervalElapsed(prefetch_time, end_logical_time_); - if (while_nest_level_[prefetch_time] == end_nest_level && - std::abs(preferred_interval - interval) < - std::abs(preferred_interval - best_interval)) { - best_interval = interval; - starting_prefetch_time = prefetch_time; - } - } VLOG(4) << "Interval min/max/preferred = " << min_interval << " " << max_interval << " " << preferred_interval << " prefetch time earliest/latest/starting = " @@ -570,7 +595,8 @@ std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString( absl::optional CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) + const { return cost_analysis_.GetMemoryBoundedness(interval); } @@ -610,7 +636,9 @@ std::string MemorySpaceAssignment::AllocationValue::ToShortString() const { } void AlternateMemoryBestFitHeap::CreateAllocationValues( - const HloValue* value, std::vector* allocation_values) { + const AlternateMemoryBestFitHeap::BufferInterval& buffer_interval, + std::vector& allocation_values) const { + const HloValue* value = buffer_interval.buffer; VLOG(3) << "Creating AllocationValues for: " << value->ToString(); // Find and sort all non-trivial (excluding GTE, Tuple, and bitcast) @@ -638,10 +666,10 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues( // Create an AllocationValue for each non-trivial position. 
absl::flat_hash_set computations; - int beginning_idx = allocation_values->size(); + int beginning_idx = allocation_values.size(); for (int i = 0; i < positions.size(); ++i) { const HloPosition& position = positions.at(i); - allocation_values->emplace_back(value, position); + allocation_values.emplace_back(value, position, buffer_interval.size); } std::vector uses(value->uses()); @@ -662,8 +690,8 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues( HloComputation* use_computation = use.instruction->parent(); AllocationValue* last_allocation_value = nullptr; - for (int i = beginning_idx; i < allocation_values->size(); ++i) { - AllocationValue* allocation_value = &allocation_values->at(i); + for (int i = beginning_idx; i < allocation_values.size(); ++i) { + AllocationValue* allocation_value = &allocation_values.at(i); if (allocation_value->computation() == use_computation && instruction_schedule.at( allocation_value->defining_position().instruction) < use_time) { @@ -674,9 +702,9 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues( last_allocation_value->AddUse(use, use_time); } - for (int i = beginning_idx; i < allocation_values->size(); ++i) { + for (int i = beginning_idx; i < allocation_values.size(); ++i) { VLOG(3) << "Created allocation value: " - << allocation_values->at(i).ToString(); + << allocation_values.at(i).ToString(); } } @@ -731,9 +759,9 @@ void AlternateMemoryBestFitHeap::FindAliases( } } -std::vector +std::vector AlternateMemoryBestFitHeap::GetSortedColocatedIntervals( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { + const AlternateMemoryBestFitHeap::BufferInterval& interval) const { std::vector colocated_intervals; std::vector worklist = {&interval}; while (!worklist.empty()) { @@ -862,7 +890,7 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( } void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + const AlternateMemoryBestFitHeap::BufferInterval& interval, std::string* debug_str) const { // Columns in buffer information: // buffer_id: int. This value can be used to match the allocation in @@ -920,27 +948,27 @@ void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString( } void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + const AllocationValue& value, const MemorySpaceAssignment::Allocation& allocation, - std::string* debug_str) const { + std::string& debug_str) const { // Columns in allocation information: // buffer_id: int. This value can be used the match with buffer info. // size: int. In bytes. // offset: int. In bytes. // start_time: int. Logical start time of the allocation. // end_time: int. Logical end time of the allocation. - if (debug_str->empty()) { + if (debug_str.empty()) { // Append the column names. 
- absl::StrAppend(debug_str, "buffer_id,size,offset,start_time,end_time\n"); + absl::StrAppend(&debug_str, "buffer_id,size,offset,start_time,end_time\n"); } if (allocation.memory_space() == MemorySpace::kAlternate) { const HloBuffer& buffer = - alias_analysis_.GetBufferContainingValue(*interval.buffer); - absl::StrAppend(debug_str, buffer.id(), ","); - absl::StrAppend(debug_str, interval.size, ","); - absl::StrAppend(debug_str, allocation.chunk().offset, ","); - absl::StrAppend(debug_str, allocation.start_time(), ","); - absl::StrAppend(debug_str, allocation.end_time(), "\n"); + alias_analysis_.GetBufferContainingValue(*value.value()); + absl::StrAppend(&debug_str, buffer.id(), ","); + absl::StrAppend(&debug_str, value.size(), ","); + absl::StrAppend(&debug_str, allocation.chunk().offset, ","); + absl::StrAppend(&debug_str, allocation.start_time(), ","); + absl::StrAppend(&debug_str, allocation.end_time(), "\n"); } } @@ -952,7 +980,7 @@ void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const { options_.dump_fn("allocinfo", allocation_info_str_); } -HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { +HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { std::vector sorted_buffer_intervals = GetSortedBufferIntervals(); @@ -971,6 +999,16 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { } } + for (const auto& interval : sorted_buffer_intervals) { + auto colocated_intervals = GetSortedColocatedIntervals(interval); + if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) { + // Increment the reserved part of alternate memory so that it is not + // available for other buffers. + reserved_in_bytes_ += options_.size_fn(*interval.buffer); + } + } + VLOG(2) << "Total reserved bytes = " << reserved_in_bytes_; + for (auto& interval : sorted_buffer_intervals) { if (!interval.need_allocation) { continue; @@ -994,12 +1032,17 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { continue; } + if (interval.size > available_heap_size()) { + VLOG(3) << "Skip " << interval.buffer->ToShortString() + << " because the buffer is larger than the heap size."; + continue; + } + auto colocated_intervals = GetSortedColocatedIntervals(interval); if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) { VLOG(3) << "Interval " << interval.buffer->ToShortString() - << " is reserved in the alternate memory. Total reserved bytes = " - << reserved_in_bytes_; + << " is reserved in the alternate memory."; for (const BufferInterval* colocated_interval : colocated_intervals) { const HloValue* value = colocated_interval->buffer; // Color all of the aliased reserved buffers here because reserved @@ -1015,10 +1058,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { options_.alternate_memory_space); } } - // Increment the reserved part of alternate memory so that it is not - // available for other buffers. Since all colocated intervals should have - // the same size, just use the first one. - reserved_in_bytes_ += options_.size_fn(*colocated_intervals[0]->buffer); continue; } @@ -1039,16 +1078,46 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { AppendBufferInfoDebugString(interval, &buffer_info_str_); + std::vector allocation_values; + CreateAllocationValuesFromColocatedIntervals(colocated_intervals, + allocation_values); + // Retry allocating this value with larger limits if allocation fails. 
+ bool repacked = false; for (int retry_number = 0; retry_number < options_.max_retries; retry_number++) { - final_retry_ = (retry_number == options_.max_retries - 1); + bool final_retry = (retry_number == options_.max_retries - 1); options_.prefetch_interval_picker->SetRetryNumber(retry_number); - bool success = AllocateColocatedIntervals(colocated_intervals); - if (success) { + Result result = + AllocateAllocationValues(absl::MakeSpan(allocation_values)); + VLOG(2) << "Allocation result = " + << absl::StrFormat("%x", static_cast(result)); + if (result_requires_uncommit(result) || + (!final_retry && result_failed_because_of_async_copy(result))) { + UncommitPendingChunks(absl::MakeSpan(allocation_values)); + VLOG(2) << "Couldn't allocate. Retry number " << retry_number; + } else if (result_is(result, Result::kFailOutOfMemory) && + num_repacks_ < options_.max_repacks && !repacked) { + UncommitPendingChunks(absl::MakeSpan(allocation_values)); + ++num_repacks_; + repacked = true; + CHECK_NE(options_.repacker, nullptr); + std::vector + repack_allocation_blocks; + ExportAllocationsForRepacking(repack_allocation_blocks); + VLOG(2) << "Repacking."; + auto repack_status = + options_.repacker->Repack(absl::MakeSpan(repack_allocation_blocks)); + CHECK_EQ(repack_status.status(), Status::OK()); + VLOG(2) << "Repack complete. Modified = " << *repack_status; + if (*repack_status) { + ImportRepackedAllocations(); + --retry_number; + } + } else { + FinalizeAllocations(absl::MakeSpan(allocation_values)); break; } - VLOG(2) << "Couldn't allocate. Retry number " << retry_number; } } @@ -1061,9 +1130,10 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { return result_; } -bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals( - const std::vector& - colocated_intervals) { +void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals( + absl::Span + colocated_intervals, + std::vector& allocation_values) { // TODO(berkin): For now, place the phi values due to conditionals in // default memory. for (const BufferInterval* colocated_interval : colocated_intervals) { @@ -1084,25 +1154,29 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals( } // Create AllocationValues for all the colocated intervals. - std::vector allocation_values; for (const auto& colocated_interval : colocated_intervals) { - CreateAllocationValues(colocated_interval->buffer, &allocation_values); + CreateAllocationValues(*colocated_interval, allocation_values); } FindAliases(&allocation_values); +} + +AlternateMemoryBestFitHeap::Result +AlternateMemoryBestFitHeap::AllocateAllocationValues( + absl::Span allocation_values) { const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); // Data structure to contain the preferred offset for a given computation. // We ensure that the same offset will be allocated outside the while loop // as well as inside the while loop. 
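The rewritten retry loop above distinguishes three outcomes per attempt: uncommit and retry with looser limits, uncommit and repack on an out-of-memory failure, or finalize. A compact sketch of just that control flow, with the allocator, repacker, and uncommit steps stubbed out as callbacks (a simplified stand-in, not the real AlternateMemoryBestFitHeap interface):

#include <functional>
#include <iostream>

enum class Outcome { kSuccess, kFailRetryable, kFailOutOfMemory };

// Retryable failures uncommit and move on to the next (more permissive) retry
// number; an out-of-memory failure may trigger at most one repack, after which
// the same retry number is attempted again.
void AllocateWithRetries(int max_retries, int max_repacks,
                         const std::function<Outcome(int)>& allocate,
                         const std::function<void()>& uncommit,
                         const std::function<bool()>& repack,
                         const std::function<void()>& finalize) {
  int num_repacks = 0;
  bool repacked = false;
  for (int retry = 0; retry < max_retries; ++retry) {
    const bool final_retry = (retry == max_retries - 1);
    Outcome result = allocate(retry);
    if (result == Outcome::kFailRetryable && !final_retry) {
      uncommit();
    } else if (result == Outcome::kFailOutOfMemory &&
               num_repacks < max_repacks && !repacked) {
      uncommit();
      ++num_repacks;
      repacked = true;
      if (repack()) {
        --retry;  // Re-attempt the same retry number with the repacked offsets.
      }
    } else {
      finalize();
      break;
    }
  }
}

int main() {
  int calls = 0;
  AllocateWithRetries(
      /*max_retries=*/3, /*max_repacks=*/1,
      [&](int) {
        ++calls;
        return calls == 1 ? Outcome::kFailOutOfMemory : Outcome::kSuccess;
      },
      [] { std::cout << "uncommit\n"; },
      [] { std::cout << "repack\n"; return true; },
      [] { std::cout << "finalize\n"; });
  std::cout << "allocate() calls: " << calls << "\n";  // Prints 2.
  return 0;
}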
- absl::flat_hash_map + absl::flat_hash_map preferred_offset_for_computation; - bool allocation_success = true; - for (auto& allocation_value : allocation_values) { + Result result = Result::kSuccess; + for (AllocationValue& allocation_value : allocation_values) { int64 definition_time = instruction_schedule.at(allocation_value.defining_instruction()); - absl::optional preferred_offset; + AliasedOffset* preferred_offset = nullptr; auto preferred_offset_it = preferred_offset_for_computation.find(allocation_value.computation()); if (preferred_offset_it != preferred_offset_for_computation.end()) { @@ -1201,10 +1275,13 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals( } } - // Bitcasts don't define buffers and don't directly consume buffers. Skip - // allocating buffers for bitcast uses. The uses that feed from bitcasts - // will be handled specially. - if (hlo_use.instruction->opcode() != HloOpcode::kBitcast) { + // Bitcasts don't define buffers and don't directly consume buffers. Skip + // allocating buffers for bitcast uses (unless they are the root + // instruction). The uses that feed from bitcasts will be handled + // specially. + if (hlo_use.instruction->opcode() != HloOpcode::kBitcast || + hlo_use.instruction == + hlo_use.instruction->parent()->root_instruction()) { AllocationRequest request; // Rarely, (e.g., when conditional true and false parameters are the // same), definition time can be the time of the conditional and use @@ -1212,20 +1289,19 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals( request.start_time = std::min(definition_time, use_time); request.end_time = use_time; request.latest_prefetch_time = latest_prefetch_time; - request.size = colocated_intervals[0]->size; + request.size = allocation_value.size(); request.allow_no_copy_alternate_mem_allocation = allow_no_copy_alternate_mem_allocation; request.earliest_prefetch_time = earliest_prefetch_time; request.preferred_offset = preferred_offset; request.use = &use; request.allocation_value = &allocation_value; - if (!AllocateSegment(request)) { + result_mark(AllocateSegment(request), result); + if (result_requires_uncommit(result)) { // If the allocation finding failed (e.g., due to running out of // asynchronous copies), then fall back to allocating the buffer // entirely in the default memory. 
- UncommitPendingChunks(); - allocation_success = false; - break; + return result; } // If there are multiple uses, they can try using the memory allocation @@ -1248,27 +1324,11 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals( if (hlo_use.instruction->opcode() == HloOpcode::kWhile && aliased_allocation->memory_space() == MemorySpace::kAlternate) { preferred_offset_for_computation[hlo_use.instruction->while_body()] = - aliased_allocation->chunk().offset; - } - } - if (!allocation_success) { - break; - } - } - if (allocation_success) { - for (AllocationValue& allocation_value : allocation_values) { - for (auto& allocation : *allocation_value.allocation_sequence()) { - AppendAllocationInfoDebugString(*colocated_intervals[0], *allocation, - &allocation_info_str_); - allocations_->push_back(std::move(allocation)); + GetAliasedOffset(*aliased_allocation); } } } - - pending_chunks_.clear(); - pending_async_copies_.clear(); - pending_required_assignments_.clear(); - return allocation_success; + return result; } bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) { @@ -1305,6 +1365,28 @@ absl::optional AsynchronousCopyOrdering::ViolatesOrdering( return absl::nullopt; } +AlternateMemoryBestFitHeap::AliasedOffset* +AlternateMemoryBestFitHeap::GetAliasedOffset( + const MemorySpaceAssignment::Allocation& allocation) { + auto aliased_offset_it = aliased_offset_map_.find(&allocation); + CHECK(aliased_offset_it != aliased_offset_map_.end()); + return aliased_offset_it->second; +} + +void AlternateMemoryBestFitHeap::CreateOrAddToAliasedOffset( + const MemorySpaceAssignment::Allocation& allocation, + AlternateMemoryBestFitHeap::AliasedOffset* aliased_offset) { + CHECK(allocation.memory_space() == MemorySpace::kAlternate); + CHECK(!aliased_offset_map_.contains(&allocation)); + if (!aliased_offset) { + aliased_offsets_.push_back({allocation.chunk().offset}); + aliased_offset = &aliased_offsets_.back(); + } + CHECK_EQ(allocation.chunk().offset, aliased_offset->offset); + CHECK(aliased_offset->allocations.insert(&allocation).second); + aliased_offset_map_[&allocation] = aliased_offset; +} + /*static*/ MemorySpaceAssignment::Allocation* AlternateMemoryBestFitHeap::GetLiveAllocationAt( const MemorySpaceAssignment::AllocationSequence& allocations, int64 time) { @@ -1345,27 +1427,87 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer( // Find the earliest use. const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); auto uses = buffer->uses(); - auto first_use = - absl::c_min_element(uses, [&](const HloUse& lhs, const HloUse& rhs) { - return instruction_schedule.at(lhs.instruction) < - instruction_schedule.at(rhs.instruction); - }); + auto use_schedule_compare = [&](const HloUse& lhs, const HloUse& rhs) { + return instruction_schedule.at(lhs.instruction) < + instruction_schedule.at(rhs.instruction); + }; + auto first_use = absl::c_min_element(uses, use_schedule_compare); int64 latest_prefetch_time = instruction_schedule.at(first_use->instruction); + // Find the latest use time. 
+ int64 last_use_time = instruction_schedule.at( + absl::c_max_element(uses, use_schedule_compare)->instruction); + for (const HloValue* colocation : prefetch_candidate->colocations) { + last_use_time = std::max( + last_use_time, + instruction_schedule.at( + absl::c_max_element(colocation->uses(), use_schedule_compare) + ->instruction)); + } + + int64 end_of_program_prefetch_end_time = instruction_schedule.size() - 1; + int64 end_of_program_prefetch_start_time = + options_.prefetch_interval_picker->PreferredPrefetchStartTime( + buffer->defining_position().shape(), last_use_time, + end_of_program_prefetch_end_time, end_of_program_prefetch_end_time); + VLOG(2) << "last use time = " << last_use_time + << ", end-of-program prefetch start time = " + << end_of_program_prefetch_start_time; + bool free_buffer = + (end_of_program_prefetch_start_time > last_use_time && + end_of_program_prefetch_start_time < end_of_program_prefetch_end_time); + int64 cross_program_prefetch_end_time = + free_buffer ? last_use_time : prefetch_candidate->end; + AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate, chunk_candidate.chunk, prefetch_candidate->start, - prefetch_candidate->end, latest_prefetch_time, &allocations); + cross_program_prefetch_end_time, latest_prefetch_time, + &allocations, /*aliased_offset=*/nullptr, + /*is_cross_program_prefetch=*/true); absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); }); + AliasedOffset* cross_program_prefetch_offset = + GetAliasedOffset(*allocations.back()); + + if (free_buffer) { + VLOG(2) << "Adding an end-of-program prefetch for freed " + "cross-program-prefetched buffer."; + AddAsyncCopy(*allocations.front(), MemorySpace::kAlternate, + chunk_candidate.chunk, end_of_program_prefetch_start_time, + end_of_program_prefetch_end_time, + end_of_program_prefetch_end_time, &allocations, + cross_program_prefetch_offset); + CHECK_EQ(cross_program_prefetch_offset->offset, + allocations.back()->chunk().offset); + } + for (auto& allocation : allocations) { allocations_->push_back(std::move(allocation)); } - pending_chunks_.clear(); - pending_async_copies_.clear(); - pending_required_assignments_.clear(); + // Add a repack allocation block for the Allocation objects in alternate + // memory. 
+ CHECK_EQ(repack_allocation_blocks_.size(), 0); + for (const auto& allocation : *allocations_) { + if (allocation->memory_space() == MemorySpace::kAlternate) { + repack_allocation_blocks_.push_back(MakeRepackAllocationBlock( + allocation->start_time(), allocation->end_time(), + allocation->chunk().size, allocation->chunk().offset, + static_cast(repack_allocation_blocks_.size()), + allocation.get())); + RepackAllocationBlock* inserted = &repack_allocation_blocks_.back(); + for (RepackAllocationBlock& colocation : repack_allocation_blocks_) { + colocation.colocations.push_back(inserted); + if (&colocation != inserted) { + inserted->colocations.push_back(&colocation); + } + } + } + } + + ClearPendingChunks(); } -absl::optional +absl::optional AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer, int64 time) const { auto required_assignment_it = required_assignments_.find(buffer); @@ -1383,7 +1525,7 @@ AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer, return required_assignment_at_time; } -absl::optional +absl::optional AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse( const AllocationValue::Use& use) const { absl::optional required_assignment; @@ -1409,26 +1551,26 @@ AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse( void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment( const HloInstruction* instruction, ShapeIndex index, const MemorySpaceAssignment::Allocation* aliased_allocation) { - absl::optional chunk; + AliasedOffset* offset = nullptr; if (aliased_allocation->memory_space() == MemorySpace::kAlternate) { - chunk = aliased_allocation->chunk(); + offset = GetAliasedOffset(*aliased_allocation); } AddRequiredAssignment(instruction, index, aliased_allocation->memory_space(), - chunk); + offset); } void AlternateMemoryBestFitHeap::AddRequiredAssignment( const HloValue* value, const HloInstruction* instruction, MemorySpaceAssignment::MemorySpace memory_space, int64 time, - absl::optional chunk) { + AliasedOffset* offset) { // Check for existing required assignment at this time and make sure it is the // same as this if there is one. auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time); if (existing_required_assignment) { CHECK(memory_space == existing_required_assignment->memory_space) << "inst = " << instruction->ToString() << " at " << time; - CHECK((!chunk && !existing_required_assignment->chunk) || - chunk->offset == existing_required_assignment->chunk->offset); + CHECK((!offset && !existing_required_assignment->offset) || + offset == existing_required_assignment->offset); VLOG(3) << "Not adding required assignment because there is one already: " << value->ToShortString() << " at " << time << " at " << (memory_space == MemorySpace::kDefault ? "def" : "alt"); @@ -1436,7 +1578,7 @@ void AlternateMemoryBestFitHeap::AddRequiredAssignment( VLOG(3) << "Adding required assignment: " << value->ToShortString() << " at " << time << " at " << (memory_space == MemorySpace::kDefault ? 
"def" : "alt"); - RequiredMemoryAssignment required_assignment{memory_space, time, chunk}; + RequiredMemoryAssignment required_assignment{memory_space, time, offset}; required_assignments_[value].push_back(required_assignment); pending_required_assignments_.push_back({value, required_assignment}); } @@ -1444,13 +1586,13 @@ void AlternateMemoryBestFitHeap::AddRequiredAssignment( void AlternateMemoryBestFitHeap::AddRequiredAssignment( const HloInstruction* instruction, ShapeIndex index, - MemorySpace memory_space, absl::optional chunk) { + MemorySpace memory_space, AliasedOffset* offset) { const HloValue* value = &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index); int64 instruction_time = hlo_live_range_.instruction_schedule().at(instruction); AddRequiredAssignment(value, instruction, memory_space, instruction_time, - chunk); + offset); } void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { @@ -1539,7 +1681,38 @@ bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory( return false; } -void AlternateMemoryBestFitHeap::UncommitPendingChunks() { +void AlternateMemoryBestFitHeap::ExportAllocationsForRepacking( + std::vector& allocations) { + for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) { + allocations.push_back(&allocation_block); + } +} + +void AlternateMemoryBestFitHeap::ImportRepackedAllocations() { + interval_tree_ = {}; + for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) { + MemorySpaceAssignment::Allocation* allocation = allocation_block.allocation; + VLOG(3) << "Moved " << allocation->ToString() << ", size " + << allocation->chunk().size << ", (" << allocation_block.start_time + << ", " << allocation_block.end_time << ") from " + << allocation_block.initial_offset << " to " + << allocation_block.offset; + allocation_block.allocation->mutable_chunk()->offset = + allocation_block.offset; + interval_tree_.Add(allocation_block.start_time, allocation_block.end_time, + {allocation_block.offset, allocation_block.size}); + allocation_block.initial_offset = allocation_block.offset; + allocation_block.offset = -1; + } +} + +void AlternateMemoryBestFitHeap::UncommitPendingChunks( + absl::Span allocation_values) { + // Clear the allocation sequence of the allocation values so that in case we + // retry allocation after uncommitting. + for (AllocationValue& allocation_value : allocation_values) { + allocation_value.allocation_sequence()->clear(); + } for (const auto& interval_and_chunk : pending_chunks_) { const BufferInterval& interval = interval_and_chunk.first; const Chunk& chunk = interval_and_chunk.second.chunk; @@ -1568,8 +1741,8 @@ void AlternateMemoryBestFitHeap::UncommitPendingChunks() { ? "def" : "alt") << " time = " << required_assignment.time << " off = " - << (required_assignment.chunk ? required_assignment.chunk->offset - : -1); + << (required_assignment.offset ? 
required_assignment.offset->offset + : -1); for (auto it = required_assignment_vector.begin(); it != required_assignment_vector.end(); ++it) { if (*it == value_and_required_assignment.second) { @@ -1578,9 +1751,56 @@ void AlternateMemoryBestFitHeap::UncommitPendingChunks() { } } } + ClearPendingChunks(); +} + +void AlternateMemoryBestFitHeap::FinalizeAllocations( + absl::Span allocation_values) { + absl::flat_hash_map> + colocation_map; + for (AllocationValue& allocation_value : allocation_values) { + for (auto& allocation : *allocation_value.allocation_sequence()) { + AppendAllocationInfoDebugString(allocation_value, *allocation, + allocation_info_str_); + allocations_->push_back(std::move(allocation)); + MemorySpaceAssignment::Allocation* inserted_allocation = + allocations_->back().get(); + if (inserted_allocation->memory_space() == MemorySpace::kAlternate) { + colocation_map[GetAliasedOffset(*inserted_allocation)].push_back( + inserted_allocation); + } + } + } + // The allocations that have the same AliasedOffset need to be colocated. + // Export these to repack_allocation_blocks_ so that we can repack them to + // reduce fragmentation. + for (auto& colocation : colocation_map) { + std::vector colocations; + for (MemorySpaceAssignment::Allocation* colocated_allocation : + colocation.second) { + repack_allocation_blocks_.push_back(MakeRepackAllocationBlock( + colocated_allocation->start_time(), colocated_allocation->end_time(), + colocated_allocation->chunk().size, + colocated_allocation->chunk().offset, + static_cast(repack_allocation_blocks_.size()), + colocated_allocation)); + colocations.push_back(&repack_allocation_blocks_.back()); + } + for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block : + colocations) { + repack_block->colocations = colocations; + } + } + ClearPendingChunks(); +} + +void AlternateMemoryBestFitHeap::ClearPendingChunks() { pending_chunks_.clear(); pending_async_copies_.clear(); pending_required_assignments_.clear(); + aliased_offset_map_.clear(); + aliased_offsets_.clear(); } void AlternateMemoryBestFitHeap::AddToPendingChunks( @@ -1593,7 +1813,7 @@ void AlternateMemoryBestFitHeap::AddToPendingChunks( CommitChunk(buffer_interval, chunk_candidate); } -bool AlternateMemoryBestFitHeap::AllocateSegment( +AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( const AllocationRequest& request) { auto allocation_sequence = request.allocation_value->allocation_sequence(); // start_time == end_time is a special case where the value is consumed @@ -1604,7 +1824,7 @@ bool AlternateMemoryBestFitHeap::AllocateSegment( GetLiveAllocationAt(*allocation_sequence, request.end_time); CHECK_NE(allocation, nullptr); allocation->AddUse(request.use->hlo_use); - return true; + return Result::kSuccess; } const HloPosition& defining_position = @@ -1656,24 +1876,37 @@ bool AlternateMemoryBestFitHeap::AllocateSegment( const auto& prev_allocation = allocation_sequence->back(); CHECK(prev_allocation->memory_space() == required_assignment_at_start->memory_space); - CHECK_EQ(prev_allocation->chunk().offset, - required_assignment_at_start->chunk->offset); + CHECK_EQ(GetAliasedOffset(*prev_allocation), + required_assignment_at_start->offset); prev_allocation->Extend(request.start_time); } else { + absl::optional aliased_chunk = absl::nullopt; + if (required_assignment_at_start->memory_space == + MemorySpace::kAlternate) { + aliased_chunk = + Chunk{required_assignment_at_start->offset->offset, request.size}; + } allocation_sequence->push_back( 
absl::make_unique( defining_position, required_assignment_at_start->memory_space, - required_assignment_at_start->chunk, request.start_time, - request.start_time)); + aliased_chunk, request.start_time, request.start_time)); + if (required_assignment_at_start->memory_space == + MemorySpace::kAlternate) { + CreateOrAddToAliasedOffset(*allocation_sequence->back(), + required_assignment_at_start->offset); + } } } + Result allocation_result = Result::kSuccess; // First try keeping the allocation entirely in the alternate memory. if (required_memory_space_at_start != MemorySpace::kDefault && required_memory_space_at_end != MemorySpace::kDefault && - request.allow_no_copy_alternate_mem_allocation && - AllocateInAlternateMemoryNoCopy(request)) { - return true; + request.allow_no_copy_alternate_mem_allocation) { + allocation_result = AllocateInAlternateMemoryNoCopy(request); + if (allocation_result == Result::kSuccess) { + return Result::kSuccess; + } } auto prev_allocation_it = allocation_sequence->rbegin(); @@ -1692,8 +1925,10 @@ bool AlternateMemoryBestFitHeap::AllocateSegment( (*prev_allocation_it)->defining_position() == defining_position) { // If there was an allocation for this HloValue that was in the alternate // memory space, we also need to perform an eviction. - if (!Evict(request)) { - return false; + Result eviction_result = Evict(request); + if (eviction_result != Result::kSuccess) { + // A non-success eviction requires us to uncommit previous allocations. + return result_mark(Result::kFailRequiresUncommit, eviction_result); } prev_allocation_in_default_mem_it = allocation_sequence->rbegin(); } else if (prev_allocation_in_default_mem_it == allocation_sequence->rend()) { @@ -1714,38 +1949,36 @@ bool AlternateMemoryBestFitHeap::AllocateSegment( << "Not trying to prefetch because use requires buffer in default mem."; (*prev_allocation_in_default_mem_it)->Extend(request.end_time); (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use); - return true; + return Result::kSuccess; } // Finally, try to prefetch the buffer into alternate memory. - if (Prefetch(request, **prev_allocation_in_default_mem_it)) { - return true; - } - if (!final_retry_ && prefetch_failed_due_to_async_copy_) { - // If prefetching failed due to asynchronous copy and we're not in our final - // try, return false (failure) so that we can retry this interval with - // larger limits. - return false; + Result prefetch_result = + Prefetch(request, **prev_allocation_in_default_mem_it); + if (prefetch_result == Result::kSuccess) { + return Result::kSuccess; } + result_mark(prefetch_result, allocation_result); // If the end assignment was required to be in alternate memory but that // wasn't possible, then this allocation is invalid. if (required_memory_space_at_end == MemorySpace::kAlternate) { - return false; + return result_mark(Result::kFailRequiresUncommit, allocation_result); } // If a copy wasn't inserted, then add this use to the latest allocation in // default memory. 
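Several hunks in this file replace raw chunk offsets with AliasedOffset* handles so that allocations that must share an alternate-memory offset can be grouped and later repacked together. A minimal sketch of that bookkeeping idea, using generic container types rather than the patch's classes:

#include <cassert>
#include <cstdint>
#include <iostream>
#include <list>
#include <set>
#include <unordered_map>

struct Allocation { int64_t chunk_offset; };

// One record per distinct alternate-memory offset; every allocation that must
// alias that offset points at the same record.
struct AliasedOffset {
  int64_t offset;
  std::set<const Allocation*> allocations;
};

class OffsetTracker {
 public:
  // Register `alloc` under `aliased` if given, otherwise start a new record.
  AliasedOffset* CreateOrAddToAliasedOffset(const Allocation& alloc,
                                            AliasedOffset* aliased) {
    if (aliased == nullptr) {
      offsets_.push_back({alloc.chunk_offset, {}});
      aliased = &offsets_.back();
    }
    assert(alloc.chunk_offset == aliased->offset);
    aliased->allocations.insert(&alloc);
    map_[&alloc] = aliased;
    return aliased;
  }
  AliasedOffset* GetAliasedOffset(const Allocation& alloc) const {
    return map_.at(&alloc);
  }

 private:
  std::list<AliasedOffset> offsets_;  // std::list keeps pointers stable.
  std::unordered_map<const Allocation*, AliasedOffset*> map_;
};

int main() {
  OffsetTracker tracker;
  Allocation a{128}, b{128};
  AliasedOffset* off = tracker.CreateOrAddToAliasedOffset(a, nullptr);
  tracker.CreateOrAddToAliasedOffset(b, off);  // b must share a's offset.
  std::cout << (tracker.GetAliasedOffset(a) == tracker.GetAliasedOffset(b))
            << "\n";  // Prints 1.
  return 0;
}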
(*prev_allocation_in_default_mem_it)->Extend(request.end_time); (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use); - return true; + return allocation_result; } void AlternateMemoryBestFitHeap::AddAsyncCopy( const MemorySpaceAssignment::Allocation& prev_allocation, MemorySpace memory_space, absl::optional chunk, int64 start_time, int64 end_time, int64 copy_done_schedule_before_time, - MemorySpaceAssignment::AllocationSequence* allocations) { + MemorySpaceAssignment::AllocationSequence* allocations, + AliasedOffset* aliased_offset, bool is_cross_program_prefetch) { VLOG(3) << "Copy to " << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault ? "default" @@ -1757,7 +1990,7 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy( allocations->push_back( absl::make_unique( prev_allocation, memory_space, chunk, start_time, end_time, - copy_done_schedule_before_time)); + copy_done_schedule_before_time, is_cross_program_prefetch)); // Register the additional async copy with the interval tree to keep track of // the limit at any given time. @@ -1767,6 +2000,7 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy( prefetch_interval_tree_.Add(start_time, copy_done_schedule_before_time, kDummyChunk); async_copy_ordering_.AddCopy(pending_async_copies_.back()); + CreateOrAddToAliasedOffset(*allocations->back(), aliased_offset); } else { eviction_interval_tree_.Add(start_time, copy_done_schedule_before_time, kDummyChunk); @@ -1805,7 +2039,8 @@ AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(int64 start_time, return async_copy_ordering_.ViolatesOrdering(start_time, end_time); } -bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( +AlternateMemoryBestFitHeap::Result +AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( const AllocationRequest& request) { MemorySpaceAssignment::Allocation* prev_allocation = nullptr; bool can_eliminate_copy = false; @@ -1824,7 +2059,7 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( } if (!can_eliminate_copy) { - return false; + return Result::kFailPrevAllocationNotInAlternateMem; } const HloPosition& defining_position = @@ -1832,7 +2067,7 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( defining_position.shape(), request.start_time + 1, request.end_time)) { - return false; + return Result::kFailLiveRangeTooLong; } BufferInterval alternate_mem_interval; @@ -1842,9 +2077,9 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( alternate_mem_interval.start = request.start_time; // Prefer the offset that was previously used for the previous allocation. - absl::optional preferred_offset; + AliasedOffset* preferred_offset = nullptr; if (prev_allocation != nullptr) { - preferred_offset = prev_allocation->chunk().offset; + preferred_offset = GetAliasedOffset(*prev_allocation); // If there is a previous allocation, set the start time one after the end // of the previous allocation's end. alternate_mem_interval.start = prev_allocation->end_time() + 1; @@ -1854,13 +2089,13 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( // Sanity check that if there is a preferred offset provided in the request, // it matches with the previous allocation. 
CHECK(!preferred_offset || request.preferred_offset == preferred_offset) - << "preferred_offset = " << *preferred_offset - << ", request.preferred_offset = " << *request.preferred_offset; + << "preferred_offset = " << preferred_offset->offset + << ", request.preferred_offset = " << request.preferred_offset->offset; preferred_offset = request.preferred_offset; } VLOG(3) << "We can eliminate copy to alternate memory. Preferred offset = " - << (preferred_offset ? *preferred_offset : -1); + << (preferred_offset ? preferred_offset->offset : -1); // In case there are additional uses after this use, we rely on the last use // time to try to reserve a chunk in the heap simulator. This is to prevent // the following scenario: @@ -1908,15 +2143,19 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( absl::make_unique( defining_position, MemorySpace::kAlternate, chunk_candidate->chunk, request.start_time, request.end_time)); + CreateOrAddToAliasedOffset( + *request.allocation_value->allocation_sequence()->back(), + preferred_offset); } request.allocation_value->allocation_sequence()->back()->AddUse( request.use->hlo_use); - return true; + return Result::kSuccess; } - return false; + return Result::kFailOutOfMemory; } -bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { +AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Evict( + const AllocationRequest& request) { CHECK_GT(request.allocation_value->allocation_sequence()->size(), 0); MemorySpaceAssignment::Allocation* prev_allocation = request.allocation_value->allocation_sequence()->back().get(); @@ -1970,7 +2209,8 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, /*chunk=*/absl::nullopt, eviction_start_time, prev_allocation->end_time(), eviction_end_time, - request.allocation_value->allocation_sequence()); + request.allocation_value->allocation_sequence(), + /*aliased_offset=*/nullptr); } else { if (eviction_violates_outstanding_copies) { VLOG(3) << "This violates the maximum async copies."; @@ -1988,7 +2228,8 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { VLOG(3) << "Eviction successful."; AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, /*chunk=*/absl::nullopt, time, time + 1, time + 1, - request.allocation_value->allocation_sequence()); + request.allocation_value->allocation_sequence(), + /*aliased_offset=*/nullptr); eviction_scheduled = true; break; } @@ -2005,22 +2246,25 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { << " and " << hlo_live_range_.flattened_instruction_sequence() .instructions()[eviction_end_time]; - return false; + return Result::kFailOutOfAsyncCopies; } } - return true; + return Result::kSuccess; } int64 AlternateMemoryBestFitHeap::FindPrefetchEndTime( const AllocationRequest& request, int64 earliest_prefetch_time) const { int64 prefetch_end_time = request.latest_prefetch_time; + const HloUse& use = request.use->hlo_use; + const Shape& shape = ShapeUtil::GetSubshape( + use.instruction->operand(use.operand_number)->shape(), use.operand_index); for (int retry_number = 0; retry_number < options_.prefetch_copy_done_reorder_max_retries; ++retry_number) { int64 latest_prefetch_time = options_.prefetch_interval_picker->LatestPrefetchStartTime( - request.use->hlo_use, earliest_prefetch_time, prefetch_end_time); + shape, earliest_prefetch_time, prefetch_end_time, &use); VLOG(4) << "Latest prefetch start
time = " << latest_prefetch_time << ", earliest prefetch start time = " << earliest_prefetch_time << ", prefetch end time = " << prefetch_end_time; @@ -2058,7 +2304,7 @@ int64 AlternateMemoryBestFitHeap::FindPrefetchEndTime( return prefetch_end_time; } -bool AlternateMemoryBestFitHeap::Prefetch( +AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Prefetch( const AllocationRequest& request, const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) { // Try partially placing the buffer in the alternate space. The time that is @@ -2092,15 +2338,12 @@ bool AlternateMemoryBestFitHeap::Prefetch( BufferInterval alternate_mem_interval; alternate_mem_interval.buffer = request.allocation_value->value(); alternate_mem_interval.size = request.size; - // If any of the prefetch intervals couldn't be used due to number of - // outstanding async copy limit or async copy ordering, set - // prefetch_failed_due_to_async_copy_. - prefetch_failed_due_to_async_copy_ = false; // While uses might be allowed to have additional outstanding prefetches. int64 extra_async_copy_limit = request.use->hlo_use.instruction->opcode() == HloOpcode::kWhile ? options_.while_use_extra_outstanding_prefetch_limit : 0; + Result result = Result::kSuccess; while (!options_.prefetch_interval_picker->Done()) { alternate_mem_interval.start = options_.prefetch_interval_picker->Next(); CHECK_LT(alternate_mem_interval.start, prefetch_end_time); @@ -2111,14 +2354,14 @@ bool AlternateMemoryBestFitHeap::Prefetch( if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start, prefetch_end_time)) { VLOG(4) << "This would violate asynchronous copy ordering."; - prefetch_failed_due_to_async_copy_ = true; + result_mark(Result::kFailViolatesAsyncCopyOrdering, result); continue; } if (ViolatesMaximumOutstandingAsyncCopies( alternate_mem_interval.start, prefetch_end_time, /*is_prefetch=*/true, extra_async_copy_limit)) { VLOG(4) << "This would violate the outstanding async copy limit."; - prefetch_failed_due_to_async_copy_ = true; + result_mark(Result::kFailOutOfAsyncCopies, result); continue; } @@ -2138,20 +2381,27 @@ bool AlternateMemoryBestFitHeap::Prefetch( AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate, chunk_candidate->chunk, alternate_mem_interval.start, request.end_time, prefetch_end_time, - request.allocation_value->allocation_sequence()); + request.allocation_value->allocation_sequence(), + request.preferred_offset); request.allocation_value->allocation_sequence()->back()->AddUse( request.use->hlo_use); - prefetch_failed_due_to_async_copy_ = false; - return true; + return Result::kSuccess; } + result_mark(Result::kFailOutOfMemory, result); + } + // If we didn't consider any prefetch intervals, then the live range was too + // short. + if (result == Result::kSuccess) { + return Result::kFailLiveRangeTooShort; + } else { + return result; } - return false; } absl::optional AlternateMemoryBestFitHeap::FindBestChunkCandidate( - const AllocationRequest& request, absl::optional preferred_offset, + const AllocationRequest& request, const AliasedOffset* preferred_offset, BufferInterval* alternate_mem_interval) const { int64 end_time = request.end_time; if (!preferred_offset) { @@ -2197,8 +2447,8 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate( // only. 
alternate_mem_interval->end = end_time; ChunkCandidate chunk_candidate = - FindChunkCandidate(*alternate_mem_interval, *preferred_offset); - if (chunk_candidate.chunk.offset == *preferred_offset) { + FindChunkCandidate(*alternate_mem_interval, preferred_offset->offset); + if (chunk_candidate.chunk.offset == preferred_offset->offset) { return chunk_candidate; } return absl::nullopt; @@ -2252,8 +2502,8 @@ MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( return x_memory_boundedness > y_memory_boundedness; } // Tie-break if the memory boundedness is the same. - return GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare()( - x, y); + return GlobalDecreasingSizeBestFitHeap< + HloValue>::GetSpatialBufferIntervalCompare()(x, y); }; } @@ -2295,6 +2545,9 @@ bool IsCrossProgramPrefetchCandidate( return value.instruction()->parent() == value.instruction()->GetModule()->entry_computation() && value.instruction()->opcode() == HloOpcode::kParameter && + (!value.shape().has_layout() || + value.shape().layout().memory_space() != + options.alternate_memory_space) && value.index().size() == 1 && value.shape().IsArray() && !value.uses().empty() && options.size_fn(value) <= options.max_size_in_bytes && @@ -2321,7 +2574,9 @@ FindCrossProgramPrefetchCandidate( const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range, const MemorySpaceAssignment::Options& options) { std::vector candidates; - for (HloValue* value : alias_analysis.dataflow_analysis().values()) { + for (const HloBuffer& buffer : alias_analysis.buffers()) { + CHECK_GE(buffer.values().size(), 1); + const HloValue* value = buffer.values().at(0); if (IsCrossProgramPrefetchCandidate(*value, options)) { MemorySpaceAssignment::BufferInterval interval; interval.buffer = value; @@ -2329,6 +2584,7 @@ FindCrossProgramPrefetchCandidate( interval.start = 0; interval.end = hlo_live_range.schedule_end_time(); interval.need_allocation = true; + interval.colocations = {++buffer.values().begin(), buffer.values().end()}; candidates.emplace_back(interval); } } @@ -2541,15 +2797,21 @@ HloInstruction* MemorySpaceAssignment::Allocation::AddGetTupleElements() { } std::string MemorySpaceAssignment::Allocation::ToString() const { - return absl::StrCat("Allocation in ", - memory_space_ == MemorySpace::kDefault ? "def" : "alt", - " defined at ", defining_position_.ToString()); + std::string memory_space_str = "def"; + if (memory_space_ == MemorySpace::kAlternate) { + memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")"); + } + return absl::StrCat("Allocation in ", memory_space_str, " defined at ", + defining_position_.ToString()); } std::string MemorySpaceAssignment::CopyAllocation::ToString() const { - return absl::StrCat("Copy Allocation in ", - memory_space_ == MemorySpace::kDefault ? 
"def" : "alt", - " from ", prev_allocation_.ToString()); + std::string memory_space_str = "def"; + if (memory_space_ == MemorySpace::kAlternate) { + memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")"); + } + return absl::StrCat("Copy Allocation in ", memory_space_str, " from ", + prev_allocation_.ToString()); } Status MemorySpaceAssignment::CopyAllocation::Process( @@ -2558,9 +2820,9 @@ Status MemorySpaceAssignment::CopyAllocation::Process( Shape shape = defining_position().shape(); HloInstruction* producing_instruction = AddGetTupleElements(); HloComputation* computation = producing_instruction->parent(); - copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary( + copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart( ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}), - HloOpcode::kCopyStart, producing_instruction)); + producing_instruction, is_cross_program_prefetch_)); copy_done_ = computation->AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_)); VLOG(4) << "Created " << copy_start_->name() diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index 87f7dd2ddae..409a44d319d 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h" namespace xla { @@ -105,7 +106,7 @@ class MemorySpaceAssignmentCostAnalysis { // BufferInterval. The larger this number, the higher priority it will be // placed in the alternate memory. float GetMemoryBoundedness( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, Cache* cache = nullptr) const; // Returns the elapsed time in seconds due to compute only. @@ -199,8 +200,15 @@ class PrefetchIntervalPicker { int64 latest_end_time) const = 0; // Returns the latest time that a prefetch can start. - virtual int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time, - int64 end_time) const = 0; + virtual int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time, + int64 end_time, + const HloUse* use) const = 0; + + // Returns the preferred time that a prefetch can start. + virtual int64 PreferredPrefetchStartTime(const Shape& shape, + int64 earliest_prefetch_start_time, + int64 latest_prefetch_start_time, + int64 prefetch_end_time) const = 0; // Returns the latest time that a prefetch can end that is less than or equal // to proposed_prefetch_end_time. @@ -234,7 +242,8 @@ class PrefetchIntervalPicker { // of placing the BufferInterval in the alternate memory. The larger value, // the more beneficial. 
virtual absl::optional BufferIntervalAlternateMemoryBenefit( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) + const { return absl::nullopt; } @@ -267,8 +276,14 @@ class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker { int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time, int64 latest_end_time) const override; - int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time, - int64 end_time) const override; + int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time, + int64 end_time, + const HloUse* use) const override; + + int64 PreferredPrefetchStartTime(const Shape& shape, + int64 earliest_prefetch_start_time, + int64 latest_prefetch_start_time, + int64 prefetch_end_time) const override; void Begin(const HloUse& use, int64 start_time, int64 end_time) override; @@ -306,11 +321,18 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time, int64 latest_end_time) const override; - int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time, - int64 end_time) const override; int64 LatestPrefetchEndTime(int64 original_prefetch_end_time, int64 proposed_prefetch_end_time) const override; + int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time, + int64 end_time, + const HloUse* use) const override; + + int64 PreferredPrefetchStartTime(const Shape& shape, + int64 earliest_prefetch_start_time, + int64 latest_prefetch_start_time, + int64 prefetch_end_time) const override; + void Begin(const HloUse& use, int64 start_time, int64 end_time) override; int64 Next() override; @@ -323,7 +345,7 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { int64 end_time) const override; absl::optional BufferIntervalAlternateMemoryBenefit( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const override; private: @@ -354,7 +376,7 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { int64 end_logical_time_; int64 earliest_prefetch_time_; int64 latest_prefetch_time_; - bool using_increasing_prefetch_time_iterator_; + bool using_increasing_prefetch_time_iterator_ = true; int64 increasing_prefetch_time_iterator_; int64 decreasing_prefetch_time_iterator_; }; @@ -369,9 +391,10 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { class MemorySpaceAssignment { public: using Chunk = HeapSimulator::Chunk; - using BufferInterval = GlobalDecreasingSizeBestFitHeap::BufferInterval; + using BufferInterval = + GlobalDecreasingSizeBestFitHeap::BufferInterval; using BufferIntervalCompare = - GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare; + GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare; using IsAllowedInAlternateMemoryFunction = std::function; @@ -379,6 +402,9 @@ class MemorySpaceAssignment { // space and a fast and small alternate memory space. enum class MemorySpace { kDefault, kAlternate }; + // Forward declaration for Allocation. + class Allocation; + // The different options to be passed to the Run() API. struct Options { // Backend-specific integer value that describes the alternate memory. @@ -424,6 +450,15 @@ class MemorySpaceAssignment { // copies or asynchronous copy ordering. 
int64 max_retries = 1; + // The maximum number of repacks that we are willing to perform in case we + // can't allocate a buffer due to running out of memory. If this value is + // greater than 0, repacker must be non-nullptr. + int64 max_repacks = 0; + + // The repacking algorithm to reduce fragmentation. Must be non-null if + // max_repacks is greater than 0. + MemorySpaceAssignmentRepacker* repacker = nullptr; + // If true, tries allocating buffers across (e.g., before and inside a while // loop body) sequential calls (kWhile, kCall, and kConditional). bool allocate_across_sequential_calls = false; @@ -511,6 +546,7 @@ class MemorySpaceAssignment { const std::vector& uses() const { return uses_; } MemorySpace memory_space() const { return memory_space_; } Chunk chunk() const { return *chunk_; } + Chunk* mutable_chunk() { return &*chunk_; } void set_start_time(int64 start_time) { start_time_ = start_time; } int64 start_time() const { return start_time_; } int64 end_time() const { return end_time_; } @@ -545,12 +581,14 @@ class MemorySpaceAssignment { public: CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space, absl::optional chunk, int64 start_time, - int64 end_time, int64 copy_done_schedule_before_time) + int64 end_time, int64 copy_done_schedule_before_time, + bool is_cross_program_prefetch = false) : Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk, start_time, end_time), prev_allocation_(prev_allocation), copy_start_schedule_after_(start_time), - copy_done_schedule_before_(copy_done_schedule_before_time) {} + copy_done_schedule_before_(copy_done_schedule_before_time), + is_cross_program_prefetch_(is_cross_program_prefetch) {} bool is_copy_allocation() const override { return true; } @@ -590,6 +628,10 @@ class MemorySpaceAssignment { copy_start_schedule_after_ = copy_start_schedule_after; } + bool is_cross_program_prefetch() const { + return is_cross_program_prefetch_; + } + bool operator==(const CopyAllocation& other) const; std::string ToString() const override; @@ -601,6 +643,7 @@ class MemorySpaceAssignment { // is before copy_done_schedule_before_. int64 copy_start_schedule_after_; int64 copy_done_schedule_before_; + bool is_cross_program_prefetch_; HloInstruction* copy_start_; HloInstruction* copy_done_; }; @@ -687,13 +730,15 @@ class MemorySpaceAssignment { std::vector aliases; }; - AllocationValue(const HloValue* value, const HloPosition& position) - : value_(value), defining_position_(position) {} + AllocationValue(const HloValue* value, const HloPosition& position, + int64 size) + : value_(value), defining_position_(position), size_(size) {} const HloPosition& defining_position() const { return defining_position_; } const HloInstruction* defining_instruction() const { return defining_position().instruction; } + int64 size() const { return size_; } const std::vector& uses() const { return uses_; } std::vector& uses() { return uses_; } const HloValue* value() const { return value_; } @@ -712,6 +757,7 @@ class MemorySpaceAssignment { private: const HloValue* value_; HloPosition defining_position_; + int64 size_; std::vector uses_; AllocationSequence allocation_sequence_; }; @@ -825,29 +871,6 @@ class MemorySpaceAssignment { absl::flat_hash_map> schedule_before_; }; -// This struct contains mandatory memory assignments at a given time. E.g., an -// input's required memory assignment time would correspond to the definition -// time of the parameter instruction, and an output's time would correspond to -// the time of last use. 
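To make the contract between max_repacks and repacker concrete, a minimal configuration sketch modeled on the Repack test added later in this patch; the 128-byte size, the alignment, and the choice of the best-fit repacker are illustrative only:

    // The repacker instance must outlive the Options that reference it.
    MemorySpaceAssignmentBestFitRepacker repacker(/*max_size=*/128,
                                                  /*alignment=*/8);
    MemorySpaceAssignment::Options options;
    options.max_size_in_bytes = 128;
    options.alignment_in_bytes = 8;
    options.max_repacks = 1;       // > 0 turns repacking on
    options.repacker = &repacker;  // required whenever max_repacks > 0
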
-struct RequiredMemoryAssignment { - MemorySpaceAssignment::MemorySpace memory_space; - int64 time; - absl::optional chunk; - - bool equals_ignoring_time(const RequiredMemoryAssignment& other) const { - return memory_space == other.memory_space && chunk == other.chunk; - } - - bool operator==(const RequiredMemoryAssignment& other) const { - return memory_space == other.memory_space && time == other.time && - chunk == other.chunk; - } - - bool operator!=(const RequiredMemoryAssignment& other) const { - return !(*this == other); - } -}; - // A struct representing an asynchronous copy with its logical start and end // time and its destination memory space. struct AsynchronousCopy { @@ -896,7 +919,8 @@ class AsynchronousCopyOrdering { // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of // maximum size. -class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { +class AlternateMemoryBestFitHeap + : public GlobalDecreasingSizeBestFitHeap { public: using MemorySpace = MemorySpaceAssignment::MemorySpace; using AllocationValue = MemorySpaceAssignment::AllocationValue; @@ -923,9 +947,23 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { void AllocateCrossProgramPrefetchBuffer( HloModule* module, absl::optional prefetch_candidate); - HeapSimulator::Result Finish() override; + HeapSimulator::Result Finish() override; private: + // We inherit AllocationBlock struct to attach the Allocation information to + // make importing repacked offsets easier. + struct RepackAllocationBlock + : MemorySpaceAssignmentRepacker::AllocationBlock { + MemorySpaceAssignment::Allocation* allocation; + }; + + // A data structure we use to associate Allocation objects that are aliased + // and must get the same offset. + struct AliasedOffset { + int64 offset; + absl::flat_hash_set allocations; + }; + // An allocation request for a use segment. A use segment is the time segment // between the definition and the first use, and the time segment between the // uses of a buffer. For example, the time between the definition and Use1, is @@ -953,11 +991,101 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { int64 size; bool allow_no_copy_alternate_mem_allocation; absl::optional earliest_prefetch_time; - absl::optional preferred_offset; + AliasedOffset* preferred_offset; const MemorySpaceAssignment::AllocationValue::Use* use; MemorySpaceAssignment::AllocationValue* allocation_value; }; + // This struct contains mandatory memory assignments at a given time. E.g., an + // input's required memory assignment time would correspond to the definition + // time of the parameter instruction, and an output's time would correspond to + // the time of last use. + struct RequiredMemoryAssignment { + MemorySpaceAssignment::MemorySpace memory_space; + int64 time; + AliasedOffset* offset; + + bool equals_ignoring_time(const RequiredMemoryAssignment& other) const { + return memory_space == other.memory_space && offset == other.offset; + } + + bool operator==(const RequiredMemoryAssignment& other) const { + return memory_space == other.memory_space && time == other.time && + offset == other.offset; + } + + bool operator!=(const RequiredMemoryAssignment& other) const { + return !(*this == other); + } + }; + + // Result of an allocation, prefetch, eviction etc. request. The result is + // either kSuccess or a bitwise OR of one or more failures. The values are + // unique powers of two. 
To check if a result contains a particular failure, + // use the result_is method. To add a new failure to a result, use the + // result_mark method. + enum class Result { + // Successful allocation. + kSuccess = 0, + // Allocation failed because we ran out of alternate memory. + kFailOutOfMemory = 1, + // A no-copy allocation couldn't be performed because the previous + // allocation wasn't in the alternate memory space. + kFailPrevAllocationNotInAlternateMem = 2, + // A no-copy allocation couldn't be performed because the live range was too + // long. + kFailLiveRangeTooLong = 4, + // A prefetching couldn't be performed because the live range was too short. + kFailLiveRangeTooShort = 8, + // Ran out of outstanding asynchronous copy limit either during prefetching + // or eviction. + kFailOutOfAsyncCopies = 16, + // A prefetching couldn't be performed because the asynchronous copy + // ordering was violated. + kFailViolatesAsyncCopyOrdering = 32, + // An allocation failure happened that requires uncommitting all the pending + // allocations. Usually this is due to a situation requiring an eviction but + // the eviction couldn't be performed. + kFailRequiresUncommit = 64 + }; + + // Return true if the result belongs to a failure. + static bool result_is(Result result, Result failure) { + return static_cast(result) & static_cast(failure); + } + + // Mark (bitwise OR) a failure to the result. + static Result result_mark(Result failure, Result& result) { + result = static_cast(static_cast(result) | + static_cast(failure)); + return result; + } + + // Return true if the result is a failure that requires us to uncommit pending + // chunks. + static bool result_requires_uncommit(Result result) { + return result_is(result, Result::kFailRequiresUncommit); + } + + // Return true if the result is a failure either due to running out of + // outstanding asynchronous copies or due to violating asynchronous copy + // ordering. + static bool result_failed_because_of_async_copy(Result result) { + return result_is(result, Result::kFailOutOfAsyncCopies) || + result_is(result, Result::kFailViolatesAsyncCopyOrdering); + } + + // Returns the AliasedOffset object associated with the allocation. + AliasedOffset* GetAliasedOffset( + const MemorySpaceAssignment::Allocation& allocation); + + // If aliased_offset is non-null, this method adds the allocation to + // aliased_offset. Otherwise, it creates a new AliasedOffset object and adds + // the allocation to this new AliasedOffset. + void CreateOrAddToAliasedOffset( + const MemorySpaceAssignment::Allocation& allocation, + AliasedOffset* aliased_offset); + // Given an allocation sequence, returns the live allocation at time with a // preference towards allocations in alternate memory. Returns nullptr if no // allocation is alive at that time. @@ -968,17 +1096,24 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { bool IsUseAllowedInAlternateMemory(const AllocationValue& value, const HloUse& use) const; - // Given an HloValue, creates AllocationValue objects and corresponding + // Given a BufferInterval, creates AllocationValue objects and corresponding // AllocationSequences and appends them into allocation_sequence_list_. - void CreateAllocationValues(const HloValue* value, - std::vector* allocation_values); + void CreateAllocationValues( + const BufferInterval& buffer_interval, + std::vector& allocation_values) const; - // Finds allocations for colocated intervals. 
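Since the whole Result machinery is new, a self-contained illustration of the power-of-two, bitwise-OR convention may help. This is plain C++ detached from the XLA headers, with only a subset of the failure values:

    #include <cstdint>
    #include <iostream>

    // Failures are distinct powers of two so several can be OR-ed into one value.
    enum class Result : std::uint64_t {
      kSuccess = 0,
      kFailOutOfMemory = 1,
      kFailLiveRangeTooShort = 8,
      kFailRequiresUncommit = 64,
    };

    bool result_is(Result result, Result failure) {
      return static_cast<std::uint64_t>(result) &
             static_cast<std::uint64_t>(failure);
    }

    Result result_mark(Result failure, Result& result) {
      result = static_cast<Result>(static_cast<std::uint64_t>(result) |
                                   static_cast<std::uint64_t>(failure));
      return result;
    }

    int main() {
      Result result = Result::kSuccess;
      result_mark(Result::kFailOutOfMemory, result);
      result_mark(Result::kFailLiveRangeTooShort, result);
      // Both failures are now recorded in the same value.
      std::cout << result_is(result, Result::kFailOutOfMemory) << "\n";       // 1
      std::cout << result_is(result, Result::kFailRequiresUncommit) << "\n";  // 0
      return 0;
    }
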
Colocated intervals consist of - // one or more BufferIntervals, each with a different HloValue. All of the - // intervals within colocated intervals have a must-alias relationship with - // each other. Returns true if allocation succeeded. - bool AllocateColocatedIntervals( - const std::vector& colocated_intervals); + // Given colocated intervals, populates allocation_values with the + // corresponding AllocationValue objects. + void CreateAllocationValuesFromColocatedIntervals( + absl::Span colocated_intervals, + std::vector& allocation_values); + + // Finds allocations for allocation values generated from colocated intervals. + // All of the allocation values have a must-alias relationship with each + // other. Returns either kSuccess if all of the sites could be placed in the + // alternate memory or a bitwise OR of failure reasons why they couldn't + Result AllocateAllocationValues( + absl::Span allocation_values); // Go through all the uses in the AllocationValues and find the aliasing // positions. @@ -996,24 +1131,26 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { // if there is enough space and if the prefetch interval picker allows. // // If an eviction (2) was requested and was unsuccessful, this method returns - // false. This means we could not find a suitable allocation, so all previous - // allocations for this buffer must be removed and allocated in the default - // memory. Otherwise, this method returns true. - bool AllocateSegment(const AllocationRequest& request); + // Result::kFailRequiresUncommit. This means we could not find a suitable + // allocation, so all previous allocations for this buffer must be removed and + // allocated in the default memory. Otherwise, this method may return + // Result::kSuccess if the buffer could be placed in alternate memory or some + // other Result with an OR of reasons why the buffer couldn't be placed in + // alternate memory. + Result AllocateSegment(const AllocationRequest& request); - // Try allocating in alternate memory without any copies. Returns true if - // successful. - bool AllocateInAlternateMemoryNoCopy(const AllocationRequest& request); + // Try allocating in alternate memory without any copies. + Result AllocateInAlternateMemoryNoCopy(const AllocationRequest& request); - // Try evicting to default memory space. Returns true if successful. - bool Evict(const AllocationRequest& request); + // Try evicting to default memory space. + Result Evict(const AllocationRequest& request); // Returns the time a copy done of a prefetch should be scheduled. int64 FindPrefetchEndTime(const AllocationRequest& request, int64 earliest_prefetch_time) const; - // Try prefetching to alternate memory space. Returns true if successful. - bool Prefetch( + // Try prefetching to alternate memory space. + Result Prefetch( const AllocationRequest& request, const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem); @@ -1021,7 +1158,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { // availability if no preferred offset is given, or at the preferred_offset if // it is given. absl::optional FindBestChunkCandidate( - const AllocationRequest& request, absl::optional preferred_offset, + const AllocationRequest& request, const AliasedOffset* preferred_offset, BufferInterval* alternate_mem_interval) const; // Returns the required assignment at a particular time, if available. 
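The header only documents how callers should react to these Result values; the following is a plausible sketch of that reaction built from the comments above, assuming a std::vector<AllocationValue> named allocation_values (UncommitPendingChunks and FinalizeAllocations are declared a bit further below):

    Result result = AllocateAllocationValues(absl::MakeSpan(allocation_values));
    if (result_requires_uncommit(result)) {
      // Every tentative decision for these values is invalid; roll them back
      // so the buffer can fall back to default memory.
      UncommitPendingChunks(absl::MakeSpan(allocation_values));
    } else {
      FinalizeAllocations(absl::MakeSpan(allocation_values));
    }
    // The OR-ed failure bits remain inspectable, e.g. to decide whether a
    // repack attempt might help on a retry:
    if (result_is(result, Result::kFailOutOfMemory) &&
        !result_failed_because_of_async_copy(result)) {
      // ... candidate for repacking ...
    }
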
@@ -1043,10 +1180,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { void AddRequiredAssignment(const HloValue* value, const HloInstruction* instruction, MemorySpace memory_space, int64 time, - absl::optional chunk = absl::nullopt); + AliasedOffset* offset = nullptr); void AddRequiredAssignment(const HloInstruction* instruction, ShapeIndex index, MemorySpace memory_space, - absl::optional chunk = absl::nullopt); + AliasedOffset* offset = nullptr); // Adds input and outputs as required assignments. void AddInputAndOutputRequiredAssignments(); @@ -1081,12 +1218,24 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { absl::optional ViolatesAsyncCopyOrdering( int64 start_time, int64 end_time) const; + // Exports the allocations for repacking and puts them into the vector in the + // parameter. + void ExportAllocationsForRepacking( + std::vector& + allocations); + + // Imports repacked allocations and updates the internal data structures + // consistent with the new packing. + void ImportRepackedAllocations(); + // Adds an asynchronous copy to the allocations. void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation, MemorySpace memory_space, absl::optional chunk, int64 start_time, int64 end_time, int64 copy_done_schedule_before_time, - MemorySpaceAssignment::AllocationSequence* allocations); + MemorySpaceAssignment::AllocationSequence* allocations, + AliasedOffset* aliased_offset, + bool is_cross_program_prefetch = false); // This method is used for committing the chunk candidate but adding it to // pending_chunks_ so that we can "uncommit" them in case we need to roll back @@ -1095,17 +1244,24 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { const ChunkCandidate& chunk_candidate); // If we need to remove the allocations for this allocation sequence, this // removes pending chunks and asynchronous copies in the respective pending - // buffers from the interval trees. - void UncommitPendingChunks(); + // buffers from the interval trees. If an allocation request returns + // kFailRequiresUncommit, this method must be called. + void UncommitPendingChunks(absl::Span allocation_values); + + // Finalizes the allocations where they can no longer be uncommitted. + void FinalizeAllocations(absl::Span allocation_values); + + // Clears all pending chunks and asynchronous copies. + void ClearPendingChunks(); // Append buffer and allocation infos for debugging and dump it into a file, // if enabled. void AppendBufferInfoDebugString(const BufferInterval& interval, std::string* debug_str) const; void AppendAllocationInfoDebugString( - const BufferInterval& interval, + const AllocationValue& value, const MemorySpaceAssignment::Allocation& allocation, - std::string* debug_str) const; + std::string& debug_str) const; void DumpDebugStringsIfEnabled() const; // Returns the available heap size in the alternate memory. @@ -1113,6 +1269,22 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { return options_.max_size_in_bytes - reserved_in_bytes_; } + // Creates and returns a RepackAllocationBlock. 
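ExportAllocationsForRepacking, ImportRepackedAllocations, num_repacks_ and options_.max_repacks hint at a retry path whose body is not visible in this header; a hedged reconstruction of how those pieces likely fit together (not the actual implementation of Finish()):

    // Presumably guarded by num_repacks_ < options_.max_repacks and only
    // attempted after an out-of-memory style failure.
    std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> blocks;
    ExportAllocationsForRepacking(blocks);
    StatusOr<bool> repacked = options_.repacker->Repack(absl::MakeSpan(blocks));
    if (repacked.ok() && *repacked) {
      // The repacker wrote new offsets into the exported blocks; fold them
      // back into the interval trees and Allocation objects.
      ImportRepackedAllocations();
      ++num_repacks_;
      // ... then retry the failed allocation values ...
    }
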
+ static RepackAllocationBlock MakeRepackAllocationBlock( + int64 start_time, int64 end_time, int64 size, int64 initial_offset, + int64 id, MemorySpaceAssignment::Allocation* allocation) { + RepackAllocationBlock allocation_block; + allocation_block.start_time = start_time; + allocation_block.end_time = end_time; + allocation_block.size = size; + allocation_block.offset = -1; + allocation_block.initial_offset = initial_offset; + allocation_block.id = id; + allocation_block.colocations = {}; + allocation_block.allocation = allocation; + return allocation_block; + } + MemorySpaceAssignment::AllocationSequence* allocations_; const MemorySpaceAssignment::Options& options_; const HloAliasAnalysis& alias_analysis_; @@ -1122,19 +1294,26 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { BufferIntervalTree prefetch_interval_tree_; BufferIntervalTree eviction_interval_tree_; AsynchronousCopyOrdering async_copy_ordering_; + // A list of RepackAllocationBlock objects that mirrors allocation sequences, + // used for repacking. We use a list here because we need pointer stability + // for aliased allocations. + std::list repack_allocation_blocks_; + int64 num_repacks_ = 0; std::vector> pending_chunks_; std::vector pending_async_copies_; std::vector> pending_required_assignments_; + // The data structure that contains AliasedOffset objects and Allocation to + // AliasedOffset map for efficient lookup. + std::list aliased_offsets_; + absl::flat_hash_map + aliased_offset_map_; // This map contains required memory assignments for HloValues (e.g., input // and outputs). absl::flat_hash_map> required_assignments_; // Number of bytes reserved in alternate memory space. int64 reserved_in_bytes_ = 0; - // Variables to control allocation retries. - bool final_retry_; - bool prefetch_failed_due_to_async_copy_; // Debug strings. std::string buffer_info_str_; std::string allocation_info_str_; diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.cc b/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.cc new file mode 100644 index 00000000000..53b092f1939 --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.cc @@ -0,0 +1,88 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.h" + +#include "tensorflow/compiler/xla/service/heap_simulator.h" + +namespace xla { + +namespace { + +using AllocationBlock = MemorySpaceAssignmentRepacker::AllocationBlock; +using Type = GlobalDecreasingSizeBestFitHeap::Type; + +// This class inherits GlobalDecreasingSizeBestFitHeap and converts +// AllocationBlock objects into BufferIntervals that the heap algorithm +// understands. 
+class BestFitRepacker + : public GlobalDecreasingSizeBestFitHeap { + public: + BestFitRepacker(int64 max_size, int64 alignment, Type type) + : GlobalDecreasingSizeBestFitHeap(alignment, type), + max_size_(max_size) {} + + void ImportAllocationBlocks(absl::Span allocations) { + allocation_blocks_ = allocations; + for (AllocationBlock* allocation_block : allocations) { + // Check if any of the colocations are already added to buffer_intervals_. + bool need_allocation = true; + auto aliased_it = absl::c_find_if( + allocation_block->colocations, [&](AllocationBlock* search) { + return buffer_intervals_.contains(search); + }); + if (aliased_it != allocation_block->colocations.end()) { + buffer_intervals_[*aliased_it].colocations.push_back(allocation_block); + need_allocation = false; + } + buffer_intervals_[allocation_block] = {allocation_block, + allocation_block->size, + allocation_block->start_time, + allocation_block->end_time, + {}, + need_allocation}; + } + } + + bool Repack() { + Finish(); + bool success = result_.heap_size <= max_size_; + if (success) { + for (AllocationBlock* block : allocation_blocks_) { + auto chunk_it = result_.chunk_map.find(block); + if (chunk_it != result_.chunk_map.end()) { + block->offset = chunk_it->second.offset; + } + } + } + return success; + } + + private: + int64 max_size_; + absl::Span allocation_blocks_; +}; + +} // namespace + +StatusOr MemorySpaceAssignmentBestFitRepacker::Repack( + absl::Span allocations) { + BestFitRepacker best_fit_repacker = + BestFitRepacker(max_size_, alignment_, type_); + best_fit_repacker.ImportAllocationBlocks(allocations); + return best_fit_repacker.Repack(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.h b/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.h new file mode 100644 index 00000000000..6937b8b0e8c --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_BEST_FIT_REPACKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_BEST_FIT_REPACKER_H_ + +#include "tensorflow/compiler/xla/service/heap_simulator.h" +#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h" + +namespace xla { + +// This is a repacker algorithm that wraps around best fit heap algorithm in +// heap simulator. 
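A usage sketch for this wrapper, closely mirroring the unit tests added below; the 100-byte budget, the block lifetimes, and the self-colocation convention are taken from those tests:

    MemorySpaceAssignmentBestFitRepacker repacker(/*max_size=*/100,
                                                  /*alignment=*/1);

    // Fields: {start_time, end_time, size, offset, initial_offset, id, colocations}.
    MemorySpaceAssignmentRepacker::AllocationBlock b0{10, 20, 10, -1, -1, 0, {}};
    MemorySpaceAssignmentRepacker::AllocationBlock b1{5, 25, 15, -1, -1, 1, {}};
    b0.colocations.push_back(&b0);  // the tests colocate each block with itself
    b1.colocations.push_back(&b1);

    std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> blocks = {&b0, &b1};
    StatusOr<bool> modified = repacker.Repack(absl::MakeSpan(blocks));
    // On success, b0.offset and b1.offset hold non-overlapping offsets within
    // the 100-byte budget; on failure the offsets are left untouched.
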
+class MemorySpaceAssignmentBestFitRepacker + : public MemorySpaceAssignmentRepacker { + public: + using Type = GlobalDecreasingSizeBestFitHeap::Type; + + explicit MemorySpaceAssignmentBestFitRepacker( + int64 max_size, int64 alignment, + Type type = GlobalDecreasingSizeBestFitHeap::kTemporal) + : MemorySpaceAssignmentRepacker(max_size, alignment), type_(type) {} + + StatusOr Repack(absl::Span allocations) override; + + private: + Type type_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_BEST_FIT_REPACKER_H_ diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker_test.cc new file mode 100644 index 00000000000..44da2828eac --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker_test.cc @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.h" + +#include "tensorflow/core/platform/test.h" + +namespace xla { + +class MemorySpaceAssignmentBestFitRepackerTest : public ::testing::Test { + protected: + using AllocationBlock = MemorySpaceAssignmentRepacker::AllocationBlock; + + MemorySpaceAssignmentBestFitRepackerTest() : repacker_(100, 1) {} + + AllocationBlock* MakeAllocationBlock(int64 start_time, int64 end_time, + int64 size, int64 initial_offset = -1) { + allocation_blocks_.push_back({start_time, + end_time, + size, + -1, + initial_offset, + static_cast(allocation_blocks_.size()), + {}}); + AllocationBlock* block = &allocation_blocks_.back(); + block->colocations.push_back(block); + return block; + } + + std::list allocation_blocks_; + MemorySpaceAssignmentBestFitRepacker repacker_; +}; + +TEST_F(MemorySpaceAssignmentBestFitRepackerTest, Simple) { + std::vector allocation_blocks; + allocation_blocks.push_back(MakeAllocationBlock(10, 20, 10)); + allocation_blocks.push_back(MakeAllocationBlock(5, 25, 15)); + EXPECT_TRUE(*repacker_.Repack(absl::MakeSpan(allocation_blocks))); + + EXPECT_EQ(allocation_blocks[0]->offset, 15); + EXPECT_EQ(allocation_blocks[1]->offset, 0); +} + +TEST_F(MemorySpaceAssignmentBestFitRepackerTest, Colocation) { + std::vector allocation_blocks; + allocation_blocks.push_back(MakeAllocationBlock(0, 2, 10)); + allocation_blocks.push_back(MakeAllocationBlock(10, 20, 10)); + // Allocation blocks 0 and 1 are colocated. 
+ allocation_blocks[0]->colocations.push_back(allocation_blocks[1]); + allocation_blocks[1]->colocations.push_back(allocation_blocks[0]); + allocation_blocks.push_back(MakeAllocationBlock(5, 25, 15)); + EXPECT_TRUE(*repacker_.Repack(absl::MakeSpan(allocation_blocks))); + + EXPECT_EQ(allocation_blocks[0]->offset, 15); + EXPECT_EQ(allocation_blocks[1]->offset, 15); + EXPECT_EQ(allocation_blocks[2]->offset, 0); +} + +TEST_F(MemorySpaceAssignmentBestFitRepackerTest, TooLarge) { + // Memory size is 100, total size of buffers is 105. + std::vector allocation_blocks; + allocation_blocks.push_back(MakeAllocationBlock(10, 20, 10)); + allocation_blocks.push_back(MakeAllocationBlock(5, 25, 15)); + allocation_blocks.push_back(MakeAllocationBlock(15, 20, 10)); + allocation_blocks.push_back(MakeAllocationBlock(12, 22, 50)); + allocation_blocks.push_back(MakeAllocationBlock(10, 18, 20)); + EXPECT_FALSE(*repacker_.Repack(absl::MakeSpan(allocation_blocks))); + + // Make sure the buffers didn't get offset assignments. + EXPECT_EQ(allocation_blocks[0]->offset, -1); + EXPECT_EQ(allocation_blocks[1]->offset, -1); + EXPECT_EQ(allocation_blocks[2]->offset, -1); + EXPECT_EQ(allocation_blocks[3]->offset, -1); + EXPECT_EQ(allocation_blocks[4]->offset, -1); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_repacking.h b/tensorflow/compiler/xla/service/memory_space_assignment_repacking.h new file mode 100644 index 00000000000..eb2f0698a95 --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_assignment_repacking.h @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_REPACKING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_REPACKING_H_ + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +// An interface to define allocation repacking algorithms. +class MemorySpaceAssignmentRepacker { + public: + MemorySpaceAssignmentRepacker(int64 max_size, int64 alignment) + : max_size_(max_size), alignment_(alignment) {} + virtual ~MemorySpaceAssignmentRepacker() = default; + + // A contiguous block of allocation consisting of start and end (logical) + // times, size, and the initial offset. After repacking, if the repacking was + // successful and the allocations were modified, the offset field holds the + // new offset. To support aliased allocations, AllocationBlock also includes a + // vector of AllocationBlock pointers, called colocations. All AllocationBlock + // objects within the colocations must get the same offset. The id should be + // unique and is used to ensure determinism for comparison tie-breaker. 
+ struct AllocationBlock { + int64 start_time; + int64 end_time; + int64 size; + int64 offset; + int64 initial_offset; + int64 id; + std::vector colocations; + + std::string ToString() const { + return absl::StrCat("[", start_time, ", ", end_time, "] : size = ", size, + ", offset = ", offset, + " initial offset = ", initial_offset); + } + + // This is required by BufferIntervalCompare as a tie breaker. Use a unique + // and deterministic id. + bool operator<(const AllocationBlock& other) const { return id < other.id; } + }; + + // Repack the AllocationBlocks provided in the parameter. Returns true if + // allocations have been modified and false if not. Returns a non-ok status if + // there was an error. + virtual StatusOr Repack(absl::Span allocations) = 0; + + protected: + int64 max_size_; + int64 alignment_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_REPACKING_H_ diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index a52a4caa12c..5af61eac5d1 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -71,19 +71,22 @@ class MemorySpaceAssignmentTest : public HloTestBase, std::unique_ptr AssignMemorySpace( HloModule* module, int64 max_outstanding_async_copies = -1, - int64 max_prefetch_interval = 10, int64 min_prefetch_interval = 2) { + int64 max_prefetch_interval = 10, int64 min_prefetch_interval = 2, + absl::optional options = absl::nullopt) { InstructionCountPrefetchIntervalPicker prefetch_interval_picker( min_prefetch_interval, max_prefetch_interval); return AssignMemorySpace(module, max_outstanding_async_copies, /*buffer_interval_compare=*/{}, - &prefetch_interval_picker); + &prefetch_interval_picker, options); } std::unique_ptr AssignMemorySpace( HloModule* module, int64 max_outstanding_async_copies, absl::optional buffer_interval_compare, - PrefetchIntervalPicker* prefetch_interval_picker) { + PrefetchIntervalPicker* prefetch_interval_picker, + absl::optional + memory_space_assignment_options = absl::nullopt) { auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; @@ -117,9 +120,15 @@ class MemorySpaceAssignmentTest : public HloTestBase, } MemorySpaceAssignment::Options options; + if (memory_space_assignment_options) { + options = *memory_space_assignment_options; + } else { + options.max_size_in_bytes = 128; + options.alignment_in_bytes = 8; + options.verify = true; + } + options.alternate_memory_space = kAlternateMemorySpace; - options.max_size_in_bytes = 128; - options.alignment_in_bytes = 8; options.buffer_interval_compare = buffer_interval_compare; options.prefetch_interval_picker = prefetch_interval_picker; options.size_fn = size_fn; @@ -127,7 +136,6 @@ class MemorySpaceAssignmentTest : public HloTestBase, options.max_outstanding_prefetches = max_outstanding_async_copies; options.max_outstanding_evictions = max_outstanding_async_copies; options.allocate_across_sequential_calls = GetParam(); - options.verify = true; auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie(); std::unique_ptr hlo_live_range = @@ -224,6 +232,24 @@ class MemorySpaceAssignmentTest : public HloTestBase, return copies; } + int64 GetAlternateMemoryOffset(const PresetAssignments& preset_assignments, + const HloInstruction* instruction, + const ShapeIndex& index = {}) const { + // Returns the offset of 
the assignment, -1 if it's not in the alternate + // memory. + const HloModule* module = instruction->parent()->parent(); + auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie(); + HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(instruction, index); + for (auto& pos_and_chunk : preset_assignments.chunks()) { + for (auto& value : buffer.values()) { + if (pos_and_chunk.first == value->defining_position()) { + return pos_and_chunk.second.offset; + } + } + } + return -1; + } + std::unique_ptr CreateEvictAndPrefetchModule() { HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); @@ -4058,6 +4084,340 @@ TEST_P(MemorySpaceAssignmentTest, MoveCopyDoneEarlier) { find_schedule_index(cos->operand(0))); } +TEST_P(MemorySpaceAssignmentTest, BitcastRoot) { + // Tests against a bug where the root of entry computation is a bitcast + // instruction and it ends up getting an allocation in the alternate memory. + absl::string_view hlo_string = R"( +HloModule primitive_computation_gather.4, is_scheduled=true + +%while_body { + %param.1 = (s32[], f32[3,3,3]) parameter(0) + %get-tuple-element.32 = s32[] get-tuple-element(%param.1), index=0 + %copy.6 = s32[] copy(s32[] %get-tuple-element.32) + %constant.8 = s32[] constant(1) + %add = s32[] add(s32[] %copy.6, s32[] %constant.8) + %get-tuple-element.35 = f32[3,3,3] get-tuple-element(%param.1), index=1 + negate = f32[3,3,3] negate(get-tuple-element.35) + ROOT %tuple.10 = (s32[], f32[3,3,3]) tuple(s32[] %add, f32[3,3,3] negate) +} + +%while_cond { + %param.0 = (s32[], f32[3,3,3]) parameter(0) + %get-tuple-element = s32[] get-tuple-element(%param.0), index=0 + %constant.3 = s32[] constant(3) + ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant.3), direction=LT +} + +ENTRY %primitive_computation_gather.4 (parameter.1: f32[3,10,5], parameter.2: s32[3,1]) -> f32[3,3,3] { + %constant.1 = s32[] constant(0) + %copy.11 = s32[] copy(s32[] %constant.1) + %constant = f32[] constant(0) + %broadcast = f32[3,3,3] broadcast(f32[] %constant), dimensions={} + %tuple.8 = (s32[], f32[3,10,5], s32[3,1], f32[3,3,3]) tuple(s32[] %copy.11, f32[3,3,3] %broadcast) + %while = (s32[], f32[3,3,3]) while(%tuple.8), condition=%while_cond, body=%while_body + %get-tuple-element.7 = f32[3,3,3] get-tuple-element(%while), index=1 + ROOT %bitcast.1 = f32[3,3,3] bitcast(f32[3,3,3] %get-tuple-element.7) +} + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_TRUE(!root->shape().has_layout() || + root->shape().layout().memory_space() == kDefaultMemorySpace); +} + +// A mock MemorySpaceAssignmentRepacker class that accepst a map of +// (start_time,offset) -> new_offset values. Using this map, the repacker +// repacks the allocations to the new_offset. 
+class FakeMemorySpaceAssignmentRepacker : public MemorySpaceAssignmentRepacker { + public: + explicit FakeMemorySpaceAssignmentRepacker( + absl::flat_hash_map, int64>& repack_map, + std::function)> check_fun = nullptr) + : MemorySpaceAssignmentRepacker(/*max_size=*/128, /*alignment=*/8), + repack_map_(repack_map), + check_fun_(check_fun) {} + + StatusOr Repack(absl::Span allocations) override { + bool modified = false; + for (AllocationBlock* block : allocations) { + absl::flat_hash_set colocations; + std::string colocations_str; + for (const AllocationBlock* colocation : block->colocations) { + absl::StrAppend(&colocations_str, colocation->id, ", "); + colocations.insert(colocation->id); + } + VLOG(1) << "Alloc id: " << block->id << " time: [" << block->start_time + << ", " << block->end_time << "] size: " << block->size + << " init offset: " << block->initial_offset << " colocations: {" + << colocations_str << "}"; + auto it = repack_map_.find({block->start_time, block->initial_offset}); + if (it != repack_map_.end()) { + modified = true; + block->offset = it->second; + } else { + block->offset = block->initial_offset; + } + for (AllocationBlock* colocation : block->colocations) { + if (it != repack_map_.end()) { + colocation->offset = it->second; + } else { + colocation->offset = colocation->initial_offset; + } + } + } + if (check_fun_) { + check_fun_(allocations); + } + + return modified; + } + + private: + // A map from (start_time, offset) to new_offset. + absl::flat_hash_map, int64> repack_map_; + std::function)> check_fun_; +}; + +TEST_P(MemorySpaceAssignmentTest, Repack) { + // We initially perform the following allocations at these offsets. + // + // Max memory + // ------------------------------------------- + // + // + // + // + // +------------+ + // | b | + // +------------+ + // +-------+ +------------+ + // | a | | n | + // +-------+ +------------+ + // ------------------------------------------- + // Min memory time -> + // + // Next up, we try to allocate the prefetch for m. 
However due to + // fragmentation, this won't be possible: + // + // Max memory + // ------------------------------------------- + // + // + // + // +---------+ + // +------------+ | + // | b | | | + // +------------+ | + // +-------+ | | +------------+ + // | a | | d | | n | + // +-------+ +---------+ +------------+ + // ------------------------------------------- + // Min memory time -> + // + // We then call repack to repack the existing allocations which allows us to + // allocate the prefetch for m: + // + // Max memory + // ------------------------------------------- + // +---------+ + // | | + // | | + // | | + // +-------+ | | + // | a | | d | + // +-------+ +---------+ + // +------------+ +------------+ + // | b | | n | + // +------------+ +------------+ + // ------------------------------------------- + // Min memory time -> + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + ENTRY Entry { + param0 = f32[8,3] parameter(0) + param1 = f32[2,4] parameter(1) + a = f32[2,4] sine(param1) + b = f32[2,4] cosine(param1) + c = f32[8,3] negate(param0) + j = f32[2,4] negate(a) + d = f32[8,3] tanh(param0) + k = f32[2,4] negate(j) + l = f32[2,4] add(b, k) + m = f32[8,3] negate(d) + n = f32[2,4] sine(l) + o = f32[8,3] negate(m) + p = f32[2,4] negate(n) + q = f32[8,3] negate(m) + ROOT tuple = (f32[2,4], f32[8,3], f32[8,3]) tuple(p, q, o) + } + )"; + + MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare = + [](const MemorySpaceAssignment::BufferInterval& a, + const MemorySpaceAssignment::BufferInterval& b) { + auto get_opcode_priority = [](const HloOpcode& opcode) { + switch (opcode) { + case HloOpcode::kSin: + return 0; + case HloOpcode::kCos: + return 1; + case HloOpcode::kTanh: + return 2; + default: + return 3; + } + }; + + return get_opcode_priority(a.buffer->defining_instruction()->opcode()) < + get_opcode_priority(b.buffer->defining_instruction()->opcode()); + }; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10); + absl::flat_hash_map, int64> repack_map; + // Move "a" from offset 0 to 32. + repack_map[{2, 0}] = 32; + // Move "b" from offset 32 to 0. + repack_map[{3, 32}] = 0; + FakeMemorySpaceAssignmentRepacker repacker = + FakeMemorySpaceAssignmentRepacker(repack_map); + MemorySpaceAssignment::Options options; + options.max_size_in_bytes = 128; + options.alignment_in_bytes = 8; + options.verify = true; + options.max_repacks = 1; + options.repacker = &repacker; + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, + buffer_interval_compare, &prefetch_interval_picker, + options); + + // If repacking succeeds, we should find the buffer for d in alternate memory. + const HloInstruction* d = + module->entry_computation()->GetInstructionWithName("d"); + EXPECT_EQ(d->shape().layout().memory_space(), kAlternateMemorySpace); +} + +TEST_P(MemorySpaceAssignmentTest, RepackExportsAliasedOffsets) { + // This test is that we are correctly exporting aliased offsets for repacking. + // In this example, the buffer produced at HLO "a" will be allocated first, + // and will consist of four allocations: + // 1) a produced in the alternate memory (and then evicted to the default + // memory). 2) a prefetched to the alternate memory to be used by q and + // while HLOs. 3) a used within the while loop body. 4) the output of while + // HLO, used by u. 
+ // + // Since a will be allocated first (the test is crafted to prioritize sine + // HLO), all four allocations should get the same (zero) offsets. However, + // while allocations 2, 3, and 4 need to be colocated with each other, + // allocation 1 doesn't need to be colocated with the other three. + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + while_condition { + param1 = (f32[2,4], f32[2,4]) parameter(0) + ROOT cond = pred[] constant(true) + } + + while_body { + param2 = (f32[2,4], f32[2,4]) parameter(0) + gte2 = f32[2,4] get-tuple-element(param2), index=0 + gte3 = f32[2,4] get-tuple-element(param2), index=1 + add = f32[2,4] add(gte2, gte3) + ROOT tuple2 = (f32[2,4], f32[2,4]) tuple(add, gte3) + } + + ENTRY Entry { + param0 = f32[2,4] parameter(0) + a = f32[2,4] sine(param0) + b = f32[2,4] negate(a) + c = f32[2,4] negate(b) + d = f32[2,4] negate(c) + e = f32[2,4] negate(d) + f = f32[2,4] negate(e) + g = f32[2,4] negate(f) + h = f32[2,4] negate(g) + i = f32[2,4] negate(h) + j = f32[2,4] negate(i) + k = f32[2,4] negate(j) + l = f32[2,4] negate(k) + m = f32[2,4] negate(l) + n = f32[2,4] negate(m) + o = f32[2,4] negate(n) + p = f32[2,4] negate(o) + q = f32[2,4] add(p, a) + tuple = (f32[2,4], f32[2,4]) tuple(q, a) + while = (f32[2,4], f32[2,4]) while(tuple), condition=while_condition, body=while_body + gte0 = f32[2,4] get-tuple-element(while), index=0 + gte1 = f32[2,4] get-tuple-element(while), index=1 + r = f32[2,4] negate(gte0) + s = f32[2,4] negate(r) + t = f32[2,4] negate(s) + constant = f32[] constant(0) + broadcast = f32[8,4] broadcast(constant), dimensions={} + cos = f32[8,4] cosine(broadcast) + u = f32[2,4] add(t, gte1) + v = f32[2,4] add(u, param0) + w = f32[8,4] negate(cos) + ROOT tuple3 = (f32[2,4], f32[8,4]) tuple(v, w) + } + )"; + + MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare = + [](const MemorySpaceAssignment::BufferInterval& a, + const MemorySpaceAssignment::BufferInterval& b) { + auto get_opcode_priority = [](const HloOpcode& opcode) { + switch (opcode) { + case HloOpcode::kSin: + return 0; + case HloOpcode::kCos: + return 1; + case HloOpcode::kTanh: + return 2; + default: + return 3; + } + }; + + return get_opcode_priority(a.buffer->defining_instruction()->opcode()) < + get_opcode_priority(b.buffer->defining_instruction()->opcode()); + }; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10); + absl::flat_hash_map, int64> repack_map; + + // Expect that of the four separate allocations for the "a" buffer, the first + // and the next three are in separate colocations. 
+ auto check_fun = + [](absl::Span + allocations) { + EXPECT_TRUE(allocations.at(0)->colocations.size() == 1 || + allocations.at(0)->colocations.size() == 3); + EXPECT_EQ(allocations.at(1)->colocations.size(), 3); + EXPECT_EQ(allocations.at(2)->colocations.size(), 3); + EXPECT_TRUE(allocations.at(3)->colocations.size() == 1 || + allocations.at(3)->colocations.size() == 3); + }; + FakeMemorySpaceAssignmentRepacker repacker = + FakeMemorySpaceAssignmentRepacker(repack_map, check_fun); + MemorySpaceAssignment::Options options; + options.max_size_in_bytes = 128; + options.alignment_in_bytes = 8; + options.verify = true; + options.max_repacks = 1; + options.repacker = &repacker; + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, + buffer_interval_compare, &prefetch_interval_picker, + options); +} + TEST_P(MemorySpaceAssignmentTest, Determinism) { // Run memory space assignment a few times to make sure every time it compiles // to the same thing. @@ -4073,6 +4433,47 @@ TEST_P(MemorySpaceAssignmentTest, Determinism) { } } +TEST_P(MemorySpaceAssignmentTest, InPlaceOp) { + // Tests that in-place ops like DynamicUpdateSlice get the same allocation as + // its input. + absl::string_view hlo_string = R"( +HloModule Module, is_scheduled=true + +fused_computation { + param0 = f32[2,3] parameter(0) + constant.1 = f32[] constant(0) + broadcast = f32[2,1] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3) +} + +ENTRY main { + param = f32[2,3] parameter(0) + negate = f32[2,3] negate(param) + fusion = f32[2,3] fusion(negate), kind=kLoop, calls=fused_computation + ROOT add = f32[2,3] add(fusion, fusion) +} + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto preset_assignments = AssignMemorySpace(module.get()); + HloInstruction* negate_instruction = + module->entry_computation()->GetInstructionWithName("negate"); + int64 negate_offset = + GetAlternateMemoryOffset(*preset_assignments, negate_instruction); + HloInstruction* fusion_instruction = + module->entry_computation()->GetInstructionWithName("fusion"); + int64 fusion_offset = + GetAlternateMemoryOffset(*preset_assignments, fusion_instruction); + // We expect negate and fusion to get the same offsets. 
+ EXPECT_EQ(negate_offset, fusion_offset); + const bool allocate_across_sequential_calls = GetParam(); + if (allocate_across_sequential_calls) { + EXPECT_NE(negate_offset, -1); + } +} + INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation, MemorySpaceAssignmentTest, ::testing::Values(false, true)); @@ -4354,6 +4755,166 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTest) { EXPECT_EQ(cross_program_prefetches.size(), 0); } +TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTest) { + HloComputation::Builder builder(TestName()); + + constexpr int kBatch = 8; + constexpr int kFeature = 8; + constexpr int kOutput = 2; + + auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature}); + auto rhs_shape = ShapeUtil::MakeShapeWithLayout( + F32, {kFeature, kOutput}, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + kAlternateMemorySpace); + auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput}); + auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + + auto lhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(lhs_shape, param, 0)); + auto rhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(rhs_shape, param, 1)); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {param, lhs, rhs, dot}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 0); +} + +TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchReuse) { + // This test is for checking if the cross-program-prefetched buffer is freed + // after its last use and there is an end-of-program prefetch. 
+ absl::string_view hlo_string = R"( + HloModule cross_program_prefetch, is_scheduled=true + + ENTRY CrossProgramPrefetch { + p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0) + get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0 + get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1 + dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + negate.1 = f32[8,2]{1,0} negate(dot) + negate.2 = f32[8,2]{1,0} negate(negate.1) + negate.3 = f32[8,2]{1,0} negate(negate.2) + negate.4 = f32[8,2]{1,0} negate(negate.3) + negate.5 = f32[8,2]{1,0} negate(negate.4) + negate.6 = f32[8,2]{1,0} negate(negate.5) + negate.7 = f32[8,2]{1,0} negate(negate.6) + negate.8 = f32[8,2]{1,0} negate(negate.7) + ROOT negate.9 = f32[8,2]{1,0} negate(negate.8) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, + /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 1); + if (!cross_program_prefetches.empty()) { + EXPECT_EQ(cross_program_prefetches[0].first, 0); + EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1})); + } + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dataflow_analysis, + HloDataflowAnalysis::Run(*module)); + const HloValue& cross_program_prefetched_value = + dataflow_analysis->GetValueDefinedAt( + module->entry_computation()->parameter_instruction(0), {1}); + // Expect that there are two prefetches that use this value, one is the + // cross-program prefetch, the other is the end-of-program prefetch. + auto is_cross_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + use.instruction->is_cross_program_prefetch(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(), + is_cross_program_prefetch), + 1); + auto is_end_of_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + !use.instruction->is_cross_program_prefetch(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(), + is_end_of_program_prefetch), + 1); +} + +TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) { + // This tests the scenario that the cross-program-prefetched buffer is used + // again close to the end of the computation. In this case, it is better not + // to free the buffer. 
+ absl::string_view hlo_string = R"( + HloModule cross_program_prefetch, is_scheduled=true + + ENTRY CrossProgramPrefetch { + p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0) + get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0 + get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1 + dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + negate.1 = f32[8,2]{1,0} negate(dot) + negate.2 = f32[8,2]{1,0} negate(negate.1) + negate.3 = f32[8,2]{1,0} negate(negate.2) + negate.4 = f32[8,2]{1,0} negate(negate.3) + negate.5 = f32[8,2]{1,0} negate(negate.4) + negate.6 = f32[8,2]{1,0} negate(negate.5) + negate.7 = f32[8,2]{1,0} negate(negate.6) + negate.8 = f32[8,2]{1,0} negate(negate.7) + ROOT dot.2 = f32[2,2]{1,0} dot(negate.8, get-tuple-element.1), lhs_contracting_dims={0}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, + /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 1); + if (!cross_program_prefetches.empty()) { + EXPECT_EQ(cross_program_prefetches[0].first, 0); + EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1})); + } + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dataflow_analysis, + HloDataflowAnalysis::Run(*module)); + const HloValue& cross_program_prefetched_value = + dataflow_analysis->GetValueDefinedAt( + module->entry_computation()->parameter_instruction(0), {1}); + // Expect that there is one prefetch that use this value, the cross-program + // prefetch. There shouldn't be an end-of-program prefetch. + auto is_cross_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + use.instruction->is_cross_program_prefetch(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(), + is_cross_program_prefetch), + 1); + auto is_end_of_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + !use.instruction->is_cross_program_prefetch(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(), + is_end_of_program_prefetch), + 0); +} + using CostAnalysisPrefetchIntervalPickerTest = HloTestBase; TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) { @@ -4578,11 +5139,12 @@ TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) { HloInstruction* root = module->entry_computation()->root_instruction(); const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; + const Shape& shape = root->operand(1)->shape(); // We expect the root's latest prefetch start time to be before the while loop // (logical time 4). - EXPECT_EQ(interval_picker.LatestPrefetchStartTime(use, /*start_time=*/0, - /*end_time=*/23), + EXPECT_EQ(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0, + /*end_time=*/23, &use), 4); } diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc index 0215f007c9c..0c44ae0d766 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc @@ -17,21 +17,21 @@ limitations under the License. 
namespace xla { -bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) { +bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory( + const HloValue* value) { // If the buffer is a tuple, don't use this algorithm for now. The buffers // that are pointed to by the tuple will still use this algorithm. Because // tuples are cheap to place in the alternate memory (they are just pointers) // we don't need to use prefetch/evict logic. - if (interval.buffer->shape().IsTuple()) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + if (value->shape().IsTuple()) { + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a tuple."; return false; } // Don't place scalars in the alternate memory. - if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + if (ShapeUtil::IsEffectiveScalar(value->shape())) { + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a scalar."; return false; } @@ -44,10 +44,10 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( // allocate TupleSelect in the alternate memory space. // TODO(berkin): Not allocating add-dependencies either since they need to be // treated specially. We should revisit this later. - for (const HloPosition& position : interval.buffer->positions()) { + for (const HloPosition& position : value->positions()) { if (position.instruction->opcode() == HloOpcode::kTupleSelect || position.instruction->opcode() == HloOpcode::kAddDependency) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it has a tuple-select or " << "add-dependency position."; return false; @@ -56,18 +56,18 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( // Send and Recv HLOs return a request identifier. These should not be // allocated in the alternate memory. - for (const HloPosition& position : interval.buffer->positions()) { + for (const HloPosition& position : value->positions()) { if ((position.instruction->opcode() == HloOpcode::kSend || position.instruction->opcode() == HloOpcode::kRecv)) { // TODO(berkin): Send/recv buffers need a stable buffer allocation // throughout sending/receiving. Disable memory space allocation for these // for now. if (position.index == ShapeIndex({0})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a send/recv buffer."; return false; } else if (position.index == ShapeIndex({1})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a request identifier for " "send/recv."; return false; @@ -78,11 +78,11 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) { // Disable memory space allocation for these for now. 
if (position.index == ShapeIndex({0})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a collective-permute buffer."; return false; } else if (position.index == ShapeIndex({1})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a collective-permute buffer."; return false; } @@ -92,4 +92,10 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( return true; } +bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) { + return IsValueAllowedInAlternateMemory(interval.buffer) && + absl::c_all_of(interval.colocations, IsValueAllowedInAlternateMemory); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_utils.h b/tensorflow/compiler/xla/service/memory_space_assignment_utils.h index 651ac107c25..082efa5eb64 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_utils.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment_utils.h @@ -26,7 +26,11 @@ class MemorySpaceAssignmentUtils { // Returns true if this buffer is allowed to be placed in the alternate // memory. static bool IsIntervalAllowedInAlternateMemory( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval); + const GlobalDecreasingSizeBestFitHeap::BufferInterval& + interval); + + // Returns true if the HloValue is allowed to be placed in alternate memory. + static bool IsValueAllowedInAlternateMemory(const HloValue* value); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index 2bcf5fa7dae..1990e962802 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -41,9 +41,12 @@ cc_library( srcs = ["emission_context.cc"], hdrs = ["emission_context.h"], deps = [ + "//tensorflow/compiler/mlir/hlo", + "//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/xla/service:hlo", "@com_google_absl//absl/strings", "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", ], ) @@ -82,8 +85,9 @@ cc_library( ":kernel_lowering", ":lhlo_dialect_emitter", "@com_google_absl//absl/container:flat_hash_map", + "@llvm-project//llvm:Core", "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMTransforms", @@ -148,51 +152,68 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/stream_executor:stream_executor_headers", "@com_google_absl//absl/container:flat_hash_map", + "@llvm-project//llvm:Core", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:StandardOps", ], ) +cc_library( + name = "passes", + srcs = ["passes.cc"], + hdrs = ["passes.h"], + deps = [ + "//tensorflow/compiler/mlir/hlo:lhlo", + "@com_google_absl//absl/memory", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Transforms", + ], +) + cc_library( name = "kernel_lowering", srcs = ["kernel_lowering.cc"], hdrs = 
["kernel_lowering.h"], deps = [ + ":passes", "//tensorflow/compiler/mlir/hlo", - "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration", "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", "//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation", "//tensorflow/compiler/mlir/hlo:legalize_to_linalg", "//tensorflow/compiler/mlir/hlo:lhlo", - "//tensorflow/compiler/mlir/hlo:lhlo_copy_removal", "//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", - "@llvm-project//mlir:Affine", "@llvm-project//mlir:AffineToStandardTransforms", "@llvm-project//mlir:CFGTransforms", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:GPUToROCDLTransforms", "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMTransforms", - "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgToLLVM", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFToGPUPass", "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc index ca979262df0..06c7ebd1099 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc @@ -16,8 +16,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h" #include "absl/strings/substitute.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -25,6 +28,8 @@ namespace mlir_gpu { EmissionContext::EmissionContext(std::unique_ptr module) : module_(std::move(module)), context_() { + context_.loadDialect(); error_handler_ = [](const ErrorMap& instructions_with_error, HloModule* module) { std::set computations_with_error; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD index 8f56548ce77..eb7cd2115f3 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD @@ -72,12 +72,14 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/platform:test", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:AffineToStandardTransforms", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:CFGTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMTransforms", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc index d5cad385324..c868d205310 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc @@ -20,6 +20,8 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project @@ -46,6 +48,7 @@ std::string CompileHloConvAndGetMlir(absl::string_view hlo_text) { hlo_module.entry_computation()->root_instruction(); mlir::MLIRContext context; + context.loadDialect(); mlir::OwningModuleRef mlir_module( mlir::ModuleOp::create(mlir::UnknownLoc::get(&context))); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 2e3fa00ca86..a9e4a2390fd 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -18,423 +18,33 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" // from @llvm-project #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project -#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/GPU/ParallelLoopMapper.h" // from @llvm-project #include "mlir/Dialect/GPU/Passes.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/Passes.h" // from @llvm-project -#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project #include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/Module.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/Region.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/BufferPlacement.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "mlir/Transforms/LoopUtils.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace mlir_gpu { -namespace { - -using ::mlir::lmhlo::FusionOp; - -// Replaces a FusionOp by the operations contained in its region. -struct FusionOpRemover - : public mlir::PassWrapper { - void runOnFunction() override { - getFunction().walk([&](FusionOp op) { - mlir::OpBuilder builder(op); - // FusionOp has a single region with a single block, so we can just walk - // over it and clone operations to the outside. 
- mlir::BlockAndValueMapping mapping; - for (auto& nested_op : op.region().front().without_terminator()) { - auto clone = builder.clone(nested_op, mapping); - for (auto pair : - llvm::zip(nested_op.getResults(), clone->getResults())) { - mapping.map(std::get<0>(pair), std::get<1>(pair)); - } - } - op.erase(); - }); - } -}; - -// Simple pass that replaces a load that immediately follows a store to the -// same address with the stored value. This needs generalization. -struct StoreForwardingPass - : mlir::PassWrapper { - mlir::StoreOp findStore(mlir::Operation* op, - std::function matches) { - // Search from op upwards in the current block. - mlir::Block* block = op->getBlock(); - auto startFromIt = - std::find_if(block->rbegin(), block->rend(), - [op](mlir::Operation& other) { return &other == op; }); - for (auto storeOpIt = startFromIt; storeOpIt != block->rend(); - ++storeOpIt) { - auto storeOp = llvm::dyn_cast(&*(storeOpIt)); - if (!storeOp || !matches(storeOp)) { - continue; - } - - return storeOp; - } - // No store operation found. Continue search outside of the parallel - // loop if block is in a parallel loop. - if (auto parallelOp = - llvm::dyn_cast(block->getParentOp())) { - return findStore(parallelOp.getOperation(), matches); - } - return {}; - } - - // Recursively search defining ops for AllocOp. Return either AllocOp if it is - // found or nullptr. - mlir::Operation* SearchAllocOp(mlir::Value memref) { - mlir::Operation* defOp = memref.getDefiningOp(); - while (auto subviewOp = mlir::dyn_cast_or_null(defOp)) { - defOp = subviewOp.source().getDefiningOp(); - } - if (auto allocOp = mlir::dyn_cast_or_null(defOp)) { - return allocOp.getOperation(); - } - return nullptr; - } - - // Retrieves AllocOp from the cache or actually looks for it. - mlir::Operation* GetAllocOp( - mlir::Value memref, - llvm::DenseMap* memrefToAllocOp) { - auto allocOpIt = memrefToAllocOp->find(memref); - if (allocOpIt != memrefToAllocOp->end()) { - return allocOpIt->second; - } - auto allocOp = SearchAllocOp(memref); - memrefToAllocOp->insert({memref, allocOp}); - return allocOp; - } - - void runOnFunction() override { - llvm::DenseMap memrefToAllocOp; - - getFunction().walk([&](mlir::LoadOp loadOp) { - auto storeOp = findStore(loadOp, [&](mlir::StoreOp storeOp) { - mlir::Operation* storeOpAlloc = - GetAllocOp(storeOp.memref(), &memrefToAllocOp); - mlir::Operation* loadOpAlloc = - GetAllocOp(loadOp.memref(), &memrefToAllocOp); - return storeOpAlloc && loadOpAlloc && (storeOpAlloc == loadOpAlloc); - }); - if (!storeOp) { - return; - } - auto storeIndices = storeOp.getIndices(); - auto loadIndices = loadOp.getIndices(); - if (!std::equal(storeIndices.begin(), storeIndices.end(), - loadIndices.begin(), loadIndices.end())) { - return; - } - loadOp.replaceAllUsesWith(storeOp.getValueToStore()); - loadOp.erase(); - }); - } -}; - -// Simple pass that removes temporary buffers that are only written to but -// never read from or that are read but the read value is not used. -// Needs an analysis that proves that loads and stores are side-effect free -// (in bounds, no aliasing, etc.). -struct DeadTempBufferRemoval - : mlir::PassWrapper { - bool operationConsideredDead(mlir::Operation* op) { - for (auto result : op->getResults()) { - if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) { - // Store and Dealloc is OK. - if (llvm::isa(op)) { - return true; - } - // Load without uses is also ok. - if (auto loadOp = llvm::dyn_cast(op)) { - return loadOp.use_empty(); - } - // Subview is ok if it is dead itself. 
- if (llvm::isa(op)) { - return operationConsideredDead(op); - } - return false; - })) { - return false; - } - } - return true; - } - - void recursiveErase(mlir::Operation* op, - llvm::SmallVectorImpl* erase_list) { - for (auto result : op->getResults()) { - for (auto user : llvm::make_early_inc_range(result.getUsers())) { - recursiveErase(user, erase_list); - } - } - erase_list->push_back(op); - } - - void runOnFunction() override { - llvm::SmallVector dead_ops; - getFunction().walk([&](mlir::AllocOp allocOp) { - if (!operationConsideredDead(allocOp)) { - return; - } - - // TODO(herhut): There should be a generic helper for this. - recursiveErase(allocOp, &dead_ops); - }); - for (auto op : dead_ops) { - op->erase(); - } - } -}; - -// TODO(herhut): Move this to MLIR core. -struct MoveScalarComputationsIntoGpuLaunch - : mlir::PassWrapper { - static bool isInliningBeneficiary(mlir::Operation* op) { - return llvm::isa(op); - } - - static bool extractBeneficiaryOps( - mlir::Operation* op, llvm::SmallVectorImpl* ops, - llvm::SetVector args) { - if (!isInliningBeneficiary(op)) { - return false; - } - - ops->push_back(op); - for (auto operand : op->getOperands()) { - // It is an existing arg, keep going. - if (args.count(operand)) { - continue; - } - mlir::Operation* definingOp = operand.getDefiningOp(); - if (!definingOp || !extractBeneficiaryOps(definingOp, ops, args)) { - return false; - } - } - return true; - } - - static void inlineOperationsIntoLaunch(mlir::gpu::LaunchOp launch) { - llvm::SetVector used_above; - mlir::getUsedValuesDefinedAbove(launch.body(), used_above); - mlir::BlockAndValueMapping inlined_map; - for (mlir::Value v : used_above) { - llvm::SmallVector ops_to_move; - mlir::Operation* definingOp = v.getDefiningOp(); - if (definingOp && - extractBeneficiaryOps(definingOp, &ops_to_move, used_above)) { - mlir::OpBuilder b(launch.body()); - for (mlir::Operation* op : llvm::reverse(ops_to_move)) { - auto result = b.clone(*op, inlined_map); - for (auto pair : llvm::zip(op->getResults(), result->getResults())) { - mlir::replaceAllUsesInRegionWith(std::get<0>(pair), - std::get<1>(pair), launch.body()); - } - inlined_map.map(op->getResults(), result->getResults()); - } - } - } - } - - void runOnFunction() override { - mlir::FuncOp fun = getFunction(); - fun.walk( - [](mlir::gpu::LaunchOp launch) { inlineOperationsIntoLaunch(launch); }); - } -}; - -// Sort the operands to the kernel for a deterministic order. First operands -// that are defined by function arguments, followed by operands that are -// returned from the function. This only works for simple functions without -// control flow and can be used in cases where the kernel is extracted and used -// independently of the host-side code. -struct RewriteKernelSignature - : mlir::PassWrapper { - void runOnFunction() override { - mlir::FuncOp func = getFunction(); - mlir::ModuleOp module = func.getParentOfType(); - getFunction().walk([&](mlir::gpu::LaunchFuncOp launchOp) { - mlir::gpu::GPUFuncOp kernel = - module.lookupSymbol(launchOp.kernel()); - - if (kernel.getNumFuncArguments() != - func.getNumArguments() + func.getNumResults()) { - kernel.emitError() - << "number of kernel arguments does not match number" - << "of arguments and results of surrounding function"; - signalPassFailure(); - return; - } - if (!llvm::hasSingleElement(func)) { - func.emitError() << "surrounding function has more than one block"; - signalPassFailure(); - return; - } - - // Compute a map from function arguments to kernel function operands. 
- mlir::BlockAndValueMapping func_to_kernel; - for (mlir::BlockArgument arg : func.getArguments()) { - for (int i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) { - if (launchOp.getKernelOperand(i) == arg) { - func_to_kernel.map(arg, kernel.getArgument(i)); - break; - } - } - } - // Also add function results that are computed by the launch. - mlir::Operation* returnOp = func.getBody().back().getTerminator(); - for (mlir::Value result : returnOp->getOperands()) { - for (int i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) { - if (launchOp.getKernelOperand(i) == result) { - func_to_kernel.map(result, kernel.getArgument(i)); - break; - } - } - } - - // Create a new kernel function with modified signature. It will have the - // parameters and result types of the original funcion as its parameter - // type and otherwise will be void. - auto gpu_module = kernel.getParentOfType(); - mlir::OpBuilder kernel_builder(gpu_module.body()); - auto operand_types = llvm::to_vector<4>(llvm::concat( - func.getType().getInputs(), func.getType().getResults())); - auto new_kernel = kernel_builder.create( - kernel.getLoc(), kernel.getName(), - kernel_builder.getFunctionType(operand_types, {})); - new_kernel.setAttr(mlir::gpu::GPUDialect::getKernelFuncAttrName(), - kernel_builder.getUnitAttr()); - - // Create a map from old kernel argument to new one. - mlir::BlockAndValueMapping old_kernel_to_new; - for (int i = 0, e = func.getNumArguments(); i < e; ++i) { - mlir::Value func_arg = func.getArgument(i); - mlir::Value new_kernel_arg = new_kernel.getArgument(i); - mlir::Value old_kernel_arg = func_to_kernel.lookupOrNull(func_arg); - if (!old_kernel_arg) { - kernel.emitOpError() - << "argument " << i - << " to containing function is not an argument to the kernel"; - signalPassFailure(); - return; - } - old_kernel_to_new.map(old_kernel_arg, new_kernel_arg); - } - for (int i = 0, e = returnOp->getNumOperands(); i < e; ++i) { - mlir::Value ret_op = returnOp->getOperand(i); - mlir::Value new_kernel_arg = - new_kernel.getArgument(func.getNumArguments() + i); - mlir::Value old_kernel_arg = func_to_kernel.lookupOrNull(ret_op); - if (!old_kernel_arg) { - kernel.emitOpError() - << "result " << i - << " of containing function is not an argument to the kernel"; - signalPassFailure(); - return; - } - old_kernel_to_new.map(old_kernel_arg, new_kernel_arg); - } - // Steal the body by appending the blocks and inserting a branch. - kernel.body().cloneInto(&new_kernel.getBody(), old_kernel_to_new); - kernel_builder.setInsertionPointToEnd(&new_kernel.body().front()); - kernel_builder.create( - new_kernel.getLoc(), &*std::next(new_kernel.body().begin())); - // Now create a new launchOp calling the new kernel. We need to forward - // the arguments of the surrounding function and operands to the return. - mlir::SmallVector new_operands; - new_operands.reserve(new_kernel.getNumFuncArguments()); - new_operands.append(func.args_begin(), func.args_end()); - new_operands.append(returnOp->operand_begin(), returnOp->operand_end()); - mlir::OpBuilder launch_builder(launchOp); - launch_builder.create( - launchOp.getLoc(), new_kernel, launchOp.getGridSizeOperandValues(), - launchOp.getBlockSizeOperandValues(), new_operands); - // Launch does not have results, so we can just erase it. And the kernel - // also needs to go. - launchOp.erase(); - kernel.erase(); - }); - } -}; - -// Extract_element(mhlo_scalars_to_dimension_tensor(v_i), i) -> v_i -// -// We need to direct fusion to the inner loops. 
This cannot be done with -// a passmanager alone ATM, as nested pass managers require operations to -// be closed from above. -struct MapParallelLoops - : public mlir::PassWrapper { - void runOnFunction() override { - mlir::greedilyMapParallelSCFToGPU(getFunction().getBody()); - } -}; - -// We need to direct fusion to the inner loops. This cannot be done with -// a passmanager alone ATM, as nested pass managers require operations to -// be closed from above. -struct FuseInnerParallelLoops - : public mlir::PassWrapper { - void runOnFunction() override { - getFunction().walk([](mlir::scf::ParallelOp op) { - mlir::scf::naivelyFuseParallelOps(op.region()); - }); - } -}; - -// Collapse all loop dimension into the first one. -struct ParallelLoopCollapsingToFirstDim - : public mlir::PassWrapper> { - void runOnOperation() override { - mlir::Operation* module = getOperation(); - - module->walk([&](mlir::scf::ParallelOp op) { - unsigned num_loops = op.getNumLoops(); - std::vector combinedLoops; - combinedLoops.reserve(num_loops); - for (unsigned i = 0; i < num_loops; ++i) { - combinedLoops.push_back(i); - } - mlir::collapseParallelLoops(op, {combinedLoops}); - }); - } -}; -} // namespace Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { mlir::PassManager pm(module.getContext()); @@ -461,9 +71,9 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { // Moving `AllocOp`s and inserting missing `DeallocOp`s pm.addPass(::mlir::createBufferPlacementPass()); // Next, we can strip the outer fusion operation. - pm.addPass(absl::make_unique()); + pm.addPass(createFusionOpRemoverPass()); // Remove unnecessary LHLO copies. - pm.addPass(::mlir::lmhlo::createLhloCopyRemovalPass()); + pm.addPass(::mlir::createCopyRemovalPass()); // Transform LHLO operations to LinAlg. pm.addPass(::mlir::lmhlo::createLegalizeLhloToLinalgPass()); // Fuse linalg operations. @@ -479,26 +89,26 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); // Fuse the inner-most loops. - pm.addPass(absl::make_unique()); + pm.addPass(createFuseInnerParallelLoopsPass()); // Run CSE to ensure that loads and stores to the same subview get // recognized as such. pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); // Forward stores to buffers to loads. - pm.addPass(absl::make_unique()); + pm.addPass(createStoreForwardingPass()); // Remove now unused temporary buffers. - pm.addPass(absl::make_unique()); + pm.addPass(createDeadTempBufferRemovalPass()); if (!options.unroll_factors.empty()) { pm.addPass(::mlir::createParallelLoopTilingPass(as_int64)); } // Project all loop dimensions to X if necessary. if (options.collapse_parallel_loops) { - pm.addPass(absl::make_unique()); + pm.addPass(createParallelLoopCollapsingToFirstDimPass()); } // Some basic cleanup. pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); // Greedily map the remaining loop to GPU hardware dimensions. - pm.addPass(absl::make_unique()); + pm.addPass(createMapParallelLoopsPass()); // Apply the mapping. pm.addPass(mlir::createParallelLoopToGpuPass()); // Some basic cleanup. @@ -515,13 +125,13 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { ::mlir::mhlo::createLegalizeTanhToApproximationPass()); } // Move scalar operations into the launch to ensure smaller signatures. 
- pm.addPass(absl::make_unique()); + pm.addPass(createMoveScalarComputationsIntoGpuLaunchPass()); // Take launches to launches with kernels. pm.addPass(::mlir::createGpuKernelOutliningPass()); // Make sure the kernel signature resembled the original function's // signature if (options.rewrite_signature) { - pm.addPass(absl::make_unique()); + pm.addPass(createRewriteKernelSignaturePass()); } if (failed(pm.run(module))) { return InternalError("Lowering to GPU kernels failed."); @@ -536,6 +146,10 @@ namespace { class LowerToNVVMPass : public ::mlir::PassWrapper< LowerToNVVMPass, ::mlir::OperationPass<::mlir::gpu::GPUModuleOp>> { + void getDependentDialects(mlir::DialectRegistry& registry) const override { + registry.insert(); + } + public: void runOnOperation() override { ::mlir::gpu::GPUModuleOp m = getOperation(); @@ -585,6 +199,85 @@ Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) { return Status::OK(); } +namespace { + +/// A pass that does the final lowering to ROCDL. It collects all the patterns +/// that are currently required, currently mixing std, linalg and gpu. +class LowerToROCDLPass + : public ::mlir::PassWrapper< + LowerToROCDLPass, ::mlir::OperationPass<::mlir::gpu::GPUModuleOp>> { + void getDependentDialects(mlir::DialectRegistry& registry) const override { + registry.insert(); + } + + public: + void runOnOperation() override { + ::mlir::gpu::GPUModuleOp m = getOperation(); + + ::mlir::OwningRewritePatternList patterns; + ::mlir::populateGpuRewritePatterns(m.getContext(), patterns); + ::mlir::applyPatternsAndFoldGreedily(m, patterns); + patterns.clear(); + + ::mlir::LLVMTypeConverter converter(m.getContext()); + ::mlir::populateStdToLLVMConversionPatterns(converter, patterns); + // TODO(b/145824979) Remove linalg once sliceop is in std. + ::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns, + &getContext()); + ::mlir::populateGpuToROCDLConversionPatterns(converter, patterns); + ::mlir::populateAffineToStdConversionPatterns(patterns, m.getContext()); + + ::mlir::ConversionTarget target(getContext()); + target.addIllegalDialect<::mlir::gpu::GPUDialect>(); + target + .addIllegalOp(); + target.addIllegalOp(); + target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); + target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>(); + // TODO(csigg): Remove once we support replacing non-root ops. + target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp, + ::mlir::gpu::YieldOp>(); + if (failed(mlir::applyFullConversion(m, target, patterns))) { + signalPassFailure(); + } + } +}; + +} // namespace + +Status LowerKernelBodiesToROCDL(mlir::ModuleOp module) { + // We cannot verify as the signature of the kernel is rewritten. + ::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false); + applyPassManagerCLOptions(pm); + + auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) { + return VLOG_IS_ON(1); + }; + pm.enableIRPrinting(/*shouldPrintBeforePass=*/{}, + /*shouldPrintAfterPass=*/enable_if_vlog_is_on, + /*printModuleScope=*/false, + /*printAfterOnlyOnChange=*/false, + /*out=*/llvm::dbgs()); + + // Rewrite kernel functions to LLVM IR. + auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>(); + kernelPm.addPass(::mlir::createLowerToCFGPass()); + kernelPm.addPass(absl::make_unique()); + + // Some basic cleanup. + kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); + kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); + // Remove all location information to prevent a debug build. 
+ kernelPm.addPass(::mlir::createStripDebugInfoPass()); + + if (failed(pm.run(module))) { + return InternalError("Lowering to ROCDL IR failed."); + } + return Status::OK(); +} + StatusOr ExtractKernelModule(mlir::ModuleOp module) { auto kernelModule = ::mlir::ModuleOp::create(module.getLoc()); // TODO(b/137624192): This also needs to resolve naming conflicts. @@ -595,5 +288,6 @@ StatusOr ExtractKernelModule(mlir::ModuleOp module) { }); return kernelModule; } + } // namespace mlir_gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h index bd633bb06cb..290550142ec 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h @@ -36,6 +36,8 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, Status LowerKernelBodiesToNVVM(mlir::ModuleOp module); +Status LowerKernelBodiesToROCDL(mlir::ModuleOp module); + StatusOr ExtractKernelModule(mlir::ModuleOp module); } // namespace mlir_gpu diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index 194eb4618d3..b275dd4525f 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "llvm/IR/DataLayout.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -203,9 +204,13 @@ LhloDialectEmitter::LhloDialectEmitter( builder_(mlir_module_.getContext()), buffer_assignment_(assignment), platform_(platform) { - LLVMDialect* llvmDialect = - mlir_module.getContext()->getRegisteredDialect(); - pointer_size_ = llvmDialect->getLLVMModule().getDataLayout().getPointerSize(); + llvm::DataLayout data_layout(""); + if (auto data_layout_attr = mlir_module.getAttrOfType( + mlir::LLVM::LLVMDialect::getDataLayoutAttrName())) { + data_layout.reset(data_layout_attr.getValue()); + } + + pointer_size_ = data_layout.getPointerSize(); } void LhloDialectEmitter::AddThunkToThunkSequence(std::unique_ptr thunk) { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index 458522f89e6..26c9e155c0c 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -25,23 +25,8 @@ limitations under the License. 
namespace xla { namespace mlir_gpu { -namespace { -using ::mlir::MLIRContext; -using ::mlir::LLVM::LLVMDialect; - -int64 ConfigureLLVMModuleAndGetPointerSize(MLIRContext* context) { - LLVMDialect* dialect = context->getRegisteredDialect(); - llvm::Module& module = dialect->getLLVMModule(); - module.setTargetTriple(gpu::nvptx::kTargetTriple); - module.setDataLayout(gpu::nvptx::kDataLayout); - return module.getDataLayout().getPointerSize(); -} - -} // namespace - -MlirCompiler::MlirCompiler() - : pointer_size_(ConfigureLLVMModuleAndGetPointerSize(&context_)) {} +MlirCompiler::MlirCompiler() : data_layout_("") {} se::Platform::Id MlirCompiler::PlatformId() const { return stream_executor::cuda::kCudaPlatformId; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h index a7b2f9446fa..261e249c0a1 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_ +#include "llvm/IR/DataLayout.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/xla/service/compiler.h" @@ -58,7 +59,7 @@ class MlirCompiler : public Compiler { protected: ::mlir::MLIRContext context_; - int64 pointer_size_; + llvm::DataLayout data_layout_; IRHook module_hook_; ErrorHandler error_handler_; }; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc index 2c2076bbd97..2e94c1a54f2 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "llvm/IR/LLVMContext.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project @@ -103,7 +104,7 @@ class MlirCompilerImpl : public MlirCompiler { const AotCompilationOptions& options) override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { - int64 pointer_size = pointer_size_; + int64 pointer_size = data_layout_.getPointerSize(); return [pointer_size](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape, pointer_size); }; @@ -292,10 +293,10 @@ Status InsertBufferLoadPreduleIntoKernel( BufferAssignment* assignment, const std::vector& buffers) { mlir::OpBuilder builder(kernel.getBody()); - auto llvm_dialect = kernel.getContext()->getRegisteredDialect(); - auto offset_type = LLVMType::getInt64Ty(llvm_dialect); - auto ptr_type = LLVMType::getInt8PtrTy(llvm_dialect); - auto void_type = LLVMType::getVoidTy(llvm_dialect); + auto* context = kernel.getContext(); + auto offset_type = LLVMType::getInt64Ty(context); + auto ptr_type = LLVMType::getInt8PtrTy(context); + auto void_type = LLVMType::getVoidTy(context); auto loc = kernel.getLoc(); auto num_original_args = kernel.getNumArguments(); @@ -461,9 +462,9 @@ StatusOr> MlirCompilerImpl::RunBackend( // must also be used to determine the thunk launch schedule. 
std::unique_ptr stream_assignment = xla::gpu::AssignStreams(*module); - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_schedule, - GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_)); + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_schedule, + GpuHloSchedule::Build(*module, *stream_assignment, + data_layout_.getPointerSize())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. @@ -543,7 +544,11 @@ StatusOr> MlirCompilerImpl::RunBackend( TF_RETURN_IF_ERROR( module_hook_.invoke(IRHook::LoweringStage::KERNEL, *kernel_module)); - auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module); + // Translate to LLVM IR in a fresh context. The module is further translated + // to textual PTX and a CUBIN blob so there is no need for the context to live + // longer than this function. + llvm::LLVMContext llvmContext; + auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module, llvmContext); if (!llvmModule) { return InternalError("Translation to LLVM failed"); @@ -575,7 +580,7 @@ StatusOr> MlirCompilerImpl::RunBackend( return {absl::make_unique( ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule), emission_context.releaseHloModule(), std::move(buffer_assignment), - nullptr, nullptr)}; + nullptr, nullptr, std::vector())}; } StatusOr>> MlirCompilerImpl::Compile( diff --git a/tensorflow/compiler/xla/service/mlir_gpu/passes.cc b/tensorflow/compiler/xla/service/mlir_gpu/passes.cc new file mode 100644 index 00000000000..887f14e90d9 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/passes.cc @@ -0,0 +1,423 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h" + +#include "absl/memory/memory.h" +#include "llvm/ADT/SetVector.h" +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/GPU/ParallelLoopMapper.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/Transforms/LoopUtils.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" + +namespace xla { +namespace mlir_gpu { +namespace { + +struct FusionOpRemoverPass + : public mlir::PassWrapper { + void runOnFunction() override { + getFunction().walk([&](mlir::lmhlo::FusionOp op) { + mlir::OpBuilder builder(op); + // FusionOp has a single region with a single block, so we can just walk + // over it and clone operations to the outside. 
+ mlir::BlockAndValueMapping mapping; + for (auto& nested_op : op.region().front().without_terminator()) { + auto clone = builder.clone(nested_op, mapping); + for (auto pair : + llvm::zip(nested_op.getResults(), clone->getResults())) { + mapping.map(std::get<0>(pair), std::get<1>(pair)); + } + } + op.erase(); + }); + } +}; + +struct StoreForwardingPass + : mlir::PassWrapper { + mlir::StoreOp findStore(mlir::Operation* op, + std::function matches) { + // Search from op upwards in the current block. + mlir::Block* block = op->getBlock(); + auto startFromIt = + std::find_if(block->rbegin(), block->rend(), + [op](mlir::Operation& other) { return &other == op; }); + for (auto storeOpIt = startFromIt; storeOpIt != block->rend(); + ++storeOpIt) { + auto storeOp = llvm::dyn_cast(&*(storeOpIt)); + if (!storeOp || !matches(storeOp)) { + continue; + } + + return storeOp; + } + // No store operation found. Continue search outside of the parallel + // loop if block is in a parallel loop. + if (auto parallelOp = + llvm::dyn_cast(block->getParentOp())) { + return findStore(parallelOp.getOperation(), matches); + } + return {}; + } + + // Recursively search defining ops for AllocOp. Return either AllocOp if it is + // found or nullptr. + mlir::Operation* SearchAllocOp(mlir::Value memref) { + mlir::Operation* defOp = memref.getDefiningOp(); + while (auto subviewOp = mlir::dyn_cast_or_null(defOp)) { + defOp = subviewOp.source().getDefiningOp(); + } + if (auto allocOp = mlir::dyn_cast_or_null(defOp)) { + return allocOp.getOperation(); + } + return nullptr; + } + + // Retrieves AllocOp from the cache or actually looks for it. + mlir::Operation* GetAllocOp( + mlir::Value memref, + llvm::DenseMap* memrefToAllocOp) { + auto allocOpIt = memrefToAllocOp->find(memref); + if (allocOpIt != memrefToAllocOp->end()) { + return allocOpIt->second; + } + auto allocOp = SearchAllocOp(memref); + memrefToAllocOp->insert({memref, allocOp}); + return allocOp; + } + + void runOnFunction() override { + llvm::DenseMap memrefToAllocOp; + + getFunction().walk([&](mlir::LoadOp loadOp) { + auto storeOp = findStore(loadOp, [&](mlir::StoreOp storeOp) { + mlir::Operation* storeOpAlloc = + GetAllocOp(storeOp.memref(), &memrefToAllocOp); + mlir::Operation* loadOpAlloc = + GetAllocOp(loadOp.memref(), &memrefToAllocOp); + return storeOpAlloc && loadOpAlloc && (storeOpAlloc == loadOpAlloc); + }); + if (!storeOp) { + return; + } + auto storeIndices = storeOp.getIndices(); + auto loadIndices = loadOp.getIndices(); + if (!std::equal(storeIndices.begin(), storeIndices.end(), + loadIndices.begin(), loadIndices.end())) { + return; + } + loadOp.replaceAllUsesWith(storeOp.getValueToStore()); + loadOp.erase(); + }); + } +}; + +struct DeadTempBufferRemovalPass + : mlir::PassWrapper { + bool operationConsideredDead(mlir::Operation* op) { + for (auto result : op->getResults()) { + if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) { + // Store and Dealloc is OK. + if (llvm::isa(op)) { + return true; + } + // Load without uses is also ok. + if (auto loadOp = llvm::dyn_cast(op)) { + return loadOp.use_empty(); + } + // Subview is ok if it is dead itself. 
+ if (llvm::isa(op)) { + return operationConsideredDead(op); + } + return false; + })) { + return false; + } + } + return true; + } + + void recursiveErase(mlir::Operation* op, + llvm::SmallVectorImpl* erase_list) { + for (auto result : op->getResults()) { + for (auto user : llvm::make_early_inc_range(result.getUsers())) { + recursiveErase(user, erase_list); + } + } + erase_list->push_back(op); + } + + void runOnFunction() override { + llvm::SmallVector dead_ops; + getFunction().walk([&](mlir::AllocOp allocOp) { + if (!operationConsideredDead(allocOp)) { + return; + } + + // TODO(herhut): There should be a generic helper for this. + recursiveErase(allocOp, &dead_ops); + }); + for (auto op : dead_ops) { + op->erase(); + } + } +}; + +struct MoveScalarComputationsIntoGpuLaunchPass + : mlir::PassWrapper { + static bool isInliningBeneficiary(mlir::Operation* op) { + return llvm::isa(op); + } + + static bool extractBeneficiaryOps( + mlir::Operation* op, llvm::SmallVectorImpl* ops, + llvm::SetVector args) { + if (!isInliningBeneficiary(op)) { + return false; + } + + ops->push_back(op); + for (auto operand : op->getOperands()) { + // It is an existing arg, keep going. + if (args.count(operand)) { + continue; + } + mlir::Operation* definingOp = operand.getDefiningOp(); + if (!definingOp || !extractBeneficiaryOps(definingOp, ops, args)) { + return false; + } + } + return true; + } + + static void inlineOperationsIntoLaunch(mlir::gpu::LaunchOp launch) { + llvm::SetVector used_above; + mlir::getUsedValuesDefinedAbove(launch.body(), used_above); + mlir::BlockAndValueMapping inlined_map; + for (mlir::Value v : used_above) { + llvm::SmallVector ops_to_move; + mlir::Operation* definingOp = v.getDefiningOp(); + if (definingOp && + extractBeneficiaryOps(definingOp, &ops_to_move, used_above)) { + mlir::OpBuilder b(launch.body()); + for (mlir::Operation* op : llvm::reverse(ops_to_move)) { + auto result = b.clone(*op, inlined_map); + for (auto pair : llvm::zip(op->getResults(), result->getResults())) { + mlir::replaceAllUsesInRegionWith(std::get<0>(pair), + std::get<1>(pair), launch.body()); + } + inlined_map.map(op->getResults(), result->getResults()); + } + } + } + } + + void runOnFunction() override { + mlir::FuncOp fun = getFunction(); + fun.walk( + [](mlir::gpu::LaunchOp launch) { inlineOperationsIntoLaunch(launch); }); + } +}; + +struct RewriteKernelSignaturePass + : mlir::PassWrapper { + void runOnFunction() override { + mlir::FuncOp func = getFunction(); + mlir::ModuleOp module = func.getParentOfType(); + getFunction().walk([&](mlir::gpu::LaunchFuncOp launchOp) { + mlir::gpu::GPUFuncOp kernel = + module.lookupSymbol(launchOp.kernel()); + + if (kernel.getNumFuncArguments() != + func.getNumArguments() + func.getNumResults()) { + kernel.emitError() + << "number of kernel arguments does not match number" + << "of arguments and results of surrounding function"; + signalPassFailure(); + return; + } + if (!llvm::hasSingleElement(func)) { + func.emitError() << "surrounding function has more than one block"; + signalPassFailure(); + return; + } + + // Compute a map from function arguments to kernel function operands. + mlir::BlockAndValueMapping func_to_kernel; + for (mlir::BlockArgument arg : func.getArguments()) { + for (int i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) { + if (launchOp.getKernelOperand(i) == arg) { + func_to_kernel.map(arg, kernel.getArgument(i)); + break; + } + } + } + // Also add function results that are computed by the launch. 
+ mlir::Operation* returnOp = func.getBody().back().getTerminator(); + for (mlir::Value result : returnOp->getOperands()) { + for (int i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) { + if (launchOp.getKernelOperand(i) == result) { + func_to_kernel.map(result, kernel.getArgument(i)); + break; + } + } + } + + // Create a new kernel function with modified signature. It will have the + // parameters and result types of the original function as its parameter + // type and otherwise will be void. + auto gpu_module = kernel.getParentOfType(); + mlir::OpBuilder kernel_builder(gpu_module.body()); + auto operand_types = llvm::to_vector<4>(llvm::concat( + func.getType().getInputs(), func.getType().getResults())); + auto new_kernel = kernel_builder.create( + kernel.getLoc(), kernel.getName(), + kernel_builder.getFunctionType(operand_types, {})); + new_kernel.setAttr(mlir::gpu::GPUDialect::getKernelFuncAttrName(), + kernel_builder.getUnitAttr()); + + // Create a map from old kernel argument to new one. + mlir::BlockAndValueMapping old_kernel_to_new; + for (int i = 0, e = func.getNumArguments(); i < e; ++i) { + mlir::Value func_arg = func.getArgument(i); + mlir::Value new_kernel_arg = new_kernel.getArgument(i); + mlir::Value old_kernel_arg = func_to_kernel.lookupOrNull(func_arg); + if (!old_kernel_arg) { + kernel.emitOpError() + << "argument " << i + << " to containing function is not an argument to the kernel"; + signalPassFailure(); + return; + } + old_kernel_to_new.map(old_kernel_arg, new_kernel_arg); + } + for (int i = 0, e = returnOp->getNumOperands(); i < e; ++i) { + mlir::Value ret_op = returnOp->getOperand(i); + mlir::Value new_kernel_arg = + new_kernel.getArgument(func.getNumArguments() + i); + mlir::Value old_kernel_arg = func_to_kernel.lookupOrNull(ret_op); + if (!old_kernel_arg) { + kernel.emitOpError() + << "result " << i + << " of containing function is not an argument to the kernel"; + signalPassFailure(); + return; + } + old_kernel_to_new.map(old_kernel_arg, new_kernel_arg); + } + // Steal the body by appending the blocks and inserting a branch. + kernel.body().cloneInto(&new_kernel.getBody(), old_kernel_to_new); + kernel_builder.setInsertionPointToEnd(&new_kernel.body().front()); + kernel_builder.create( + new_kernel.getLoc(), &*std::next(new_kernel.body().begin())); + // Now create a new launchOp calling the new kernel. We need to forward + // the arguments of the surrounding function and operands to the return. + mlir::SmallVector new_operands; + new_operands.reserve(new_kernel.getNumFuncArguments()); + new_operands.append(func.args_begin(), func.args_end()); + new_operands.append(returnOp->operand_begin(), returnOp->operand_end()); + mlir::OpBuilder launch_builder(launchOp); + launch_builder.create( + launchOp.getLoc(), new_kernel, launchOp.getGridSizeOperandValues(), + launchOp.getBlockSizeOperandValues(), new_operands); + // Launch does not have results, so we can just erase it. And the kernel + // also needs to go.
+ launchOp.erase(); + kernel.erase(); + }); + } +}; + +struct MapParallelLoopsPass + : public mlir::PassWrapper { + void runOnFunction() override { + mlir::greedilyMapParallelSCFToGPU(getFunction().getBody()); + } +}; + +struct FuseInnerParallelLoopsPass + : public mlir::PassWrapper { + void runOnFunction() override { + getFunction().walk([](mlir::scf::ParallelOp op) { + mlir::scf::naivelyFuseParallelOps(op.region()); + }); + } +}; + +struct ParallelLoopCollapsingToFirstDimPass + : public mlir::PassWrapper> { + void runOnOperation() override { + mlir::Operation* module = getOperation(); + + module->walk([&](mlir::scf::ParallelOp op) { + unsigned num_loops = op.getNumLoops(); + std::vector combinedLoops; + combinedLoops.reserve(num_loops); + for (unsigned i = 0; i < num_loops; ++i) { + combinedLoops.push_back(i); + } + mlir::collapseParallelLoops(op, {combinedLoops}); + }); + } +}; + +} // namespace + +std::unique_ptr createFusionOpRemoverPass() { + return absl::make_unique(); +} + +std::unique_ptr createStoreForwardingPass() { + return absl::make_unique(); +} + +std::unique_ptr createDeadTempBufferRemovalPass() { + return absl::make_unique(); +} + +std::unique_ptr +createMoveScalarComputationsIntoGpuLaunchPass() { + return absl::make_unique(); +} + +std::unique_ptr createRewriteKernelSignaturePass() { + return absl::make_unique(); +} + +std::unique_ptr createFuseInnerParallelLoopsPass() { + return absl::make_unique(); +} + +std::unique_ptr createMapParallelLoopsPass() { + return absl::make_unique(); +} + +std::unique_ptr> +createParallelLoopCollapsingToFirstDimPass() { + return absl::make_unique(); +} + +} // namespace mlir_gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/passes.h b/tensorflow/compiler/xla/service/mlir_gpu/passes.h new file mode 100644 index 00000000000..e3840628a2e --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/passes.h @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace xla { +namespace mlir_gpu { + +// TODO(herhut, pifon): Move these passes to MLIR Core. + +/// Replaces a FusionOp by the operations contained in its region. +std::unique_ptr createFusionOpRemoverPass(); + +/// Replaces a load that immediately follows a store to the same address with +/// the stored value. This needs generalization. +std::unique_ptr createStoreForwardingPass(); + +/// Removes temporary buffers that are only written to but never read from or +/// that are read but the read value is not used. Needs an analysis that proves +/// that loads and stores are side-effect free (in bounds, no aliasing, etc.). +std::unique_ptr createDeadTempBufferRemovalPass(); + +/// Moves scalar computations to the GPULaunchOp body. 
+std::unique_ptr +createMoveScalarComputationsIntoGpuLaunchPass(); + +/// Sorts the operands to the kernel for a deterministic order. First operands +/// that are defined by function arguments, followed by operands that are +/// returned from the function. This only works for simple functions without +/// control flow and can be used in cases where the kernel is extracted and used +/// independently of the host-side code. +std::unique_ptr createRewriteKernelSignaturePass(); + +/// We need to direct fusion to the inner loops. This cannot be done with +/// a passmanager alone ATM, as nested pass managers require operations to +/// be closed from above. +std::unique_ptr createFuseInnerParallelLoopsPass(); + +/// Greedily maps loops to GPU hardware dimensions. +std::unique_ptr createMapParallelLoopsPass(); + +/// Collapses all loop dimension into the first one. +std::unique_ptr> +createParallelLoopCollapsingToFirstDimPass(); + +} // namespace mlir_gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_H_ diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index a21cec538d1..c5c2d081686 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -338,6 +339,21 @@ bool MultiOutputFusion::LegalToFuseMainConstraints(HloInstruction* instr1, if (!ShapesCompatibleForFusion(instr1, instr2)) { return false; } + + // If both nodes are in-place operations and they use a common in-place + // operand, we can't fuse these two. + for (const auto& operand_and_output_index1 : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr1)) { + const HloInstruction* operand = + instr1->operand(operand_and_output_index1.first.operand_number); + for (const auto& operand_and_output_index2 : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr2)) { + if (operand == + instr2->operand(operand_and_output_index2.first.operand_number)) { + return false; + } + } + } return true; } diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index febbf9294b0..eb29fa89098 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -351,8 +351,7 @@ class AllOfPattern { // Returns a pattern that represents the conjunction of all input patterns. All // patterns need to match in order to have the AllOf pattern match. template -detail::AllOfPattern::type, Patterns...> AllOf( - const Patterns&... patterns) { +auto AllOf(const Patterns&... patterns) { return detail::AllOfPattern::type, Patterns...>(patterns...); } @@ -361,10 +360,8 @@ detail::AllOfPattern::type, Patterns...> AllOf( // // This transformation is necessary for good pretty-printing. template -detail::AllOfPattern::type, InnerPs..., - OuterPs...> -AllOf(const detail::AllOfPattern& inner_p, - const OuterPs&... outer_ps) { +auto AllOf(const detail::AllOfPattern& inner_p, + const OuterPs&... outer_ps) { // Invoke constructor of AllOfPattern. auto make_all_of = [](const InnerPs&... inner_ps, const OuterPs&... 
outer_ps) { @@ -453,10 +450,7 @@ template class LayoutPattern { private: template - auto AppendImpl(NewImpl new_impl) const - -> LayoutPattern(std::declval(), - std::move(new_impl)))> { + auto AppendImpl(NewImpl new_impl) const { auto new_allof = AllOf<::xla::Layout>(impl_, std::move(new_impl)); return LayoutPattern(std::move(new_allof), matched_layout_); @@ -495,14 +489,12 @@ class LayoutPattern { // Modifies the pattern to match only if the layout equals the given proto. // The layout must outlive the returned pattern. - constexpr auto EqualTo(const ::xla::Layout* layout) const - -> decltype(this->AppendImpl(LayoutPatternEqualImpl(layout))) { + constexpr auto EqualTo(const ::xla::Layout* layout) const { return AppendImpl(LayoutPatternEqualImpl(layout)); } // Modifies the pattern to match only if the layout has a dense format. - constexpr auto WithDenseFormat() const - -> decltype(this->AppendImpl(LayoutPatternFormatImpl(DENSE))) { + constexpr auto WithDenseFormat() const { return AppendImpl(LayoutPatternFormatImpl(DENSE)); } @@ -626,17 +618,14 @@ class AnyOfPattern { // patterns. The returned pattern matches from left to right, and stops on the // first match. template -detail::AnyOfPattern::type, Patterns...> AnyOf( - const Patterns&... patterns) { +auto AnyOf(const Patterns&... patterns) { return detail::AnyOfPattern::type, Patterns...>(patterns...); } // Creates a layout pattern that will capture the matched layout in the // argument. -inline constexpr detail::LayoutPattern -Layout(const ::xla::Layout** matched_layout = nullptr) { +inline constexpr auto Layout(const ::xla::Layout** matched_layout = nullptr) { return detail::LayoutPattern( detail::LayoutPatternBaseImpl(), matched_layout); @@ -644,9 +633,7 @@ Layout(const ::xla::Layout** matched_layout = nullptr) { // Creates a layout pattern that will capture the matched layout in the // argument. -inline constexpr detail::LayoutPattern<::xla::Layout, - detail::LayoutPatternBaseImpl> -Layout(::xla::Layout** matched_layout) { +inline constexpr auto Layout(::xla::Layout** matched_layout) { return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>( detail::LayoutPatternBaseImpl(), matched_layout); } @@ -939,10 +926,7 @@ template class ShapePattern { private: template - auto AppendImpl(NewImpl new_impl) const - -> ShapePattern(std::declval(), - std::move(new_impl)))> { + auto AppendImpl(NewImpl new_impl) const { auto new_all_of = AllOf<::xla::Shape>(impl_, std::move(new_impl)); return ShapePattern(std::move(new_all_of), matched_shape_); @@ -988,80 +972,66 @@ class ShapePattern { // Modifies the pattern to match only if the shape equals the given proto. // The layout must outlive the returned pattern. - constexpr auto EqualTo(const ::xla::Shape* shape) const - -> decltype(this->AppendImpl(ShapePatternEqualImpl(shape))) { + constexpr auto EqualTo(const ::xla::Shape* shape) const { return AppendImpl(ShapePatternEqualImpl(shape)); } // Modifies the pattern to match only if the shape is compatible to the given // proto. The layout must outlive the returned pattern. - constexpr auto CompatibleTo(const ::xla::Shape* shape) const - -> decltype(this->AppendImpl(ShapePatternCompatibleImpl(shape))) { + constexpr auto CompatibleTo(const ::xla::Shape* shape) const { return AppendImpl(ShapePatternCompatibleImpl(shape)); } // Modifies the pattern to match only if the shape has the given element type. 
- constexpr auto WithElementType(PrimitiveType element_type) const - -> decltype(this->AppendImpl(ShapePatternElementTypeImpl(element_type))) { + constexpr auto WithElementType(PrimitiveType element_type) const { return AppendImpl(ShapePatternElementTypeImpl(element_type)); } // Modifies the pattern to match only if the shape is scalar. - constexpr auto IsScalar() const - -> decltype(this->AppendImpl(ShapePatternIsScalarImpl())) { + constexpr auto IsScalar() const { return AppendImpl(ShapePatternIsScalarImpl()); } // Modifies the pattern to match only if the shape is an array. - constexpr auto IsArray() const - -> decltype(this->AppendImpl(ShapePatternIsArrayImpl())) { + constexpr auto IsArray() const { return AppendImpl(ShapePatternIsArrayImpl()); } // Modifies the pattern to match only if the shape is a tuple. - constexpr auto IsTuple() const - -> decltype(this->AppendImpl(ShapePatternIsTupleImpl())) { + constexpr auto IsTuple() const { return AppendImpl(ShapePatternIsTupleImpl()); } - constexpr auto IsEffectiveScalar() const - -> decltype(this->AppendImpl(ShapePatternEffectiveScalarImpl())) { + constexpr auto IsEffectiveScalar() const { return AppendImpl(ShapePatternEffectiveScalarImpl()); } // Modifies the pattern to match only if the shape has the given rank. - constexpr auto WithRank(int64 rank) const - -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) { + constexpr auto WithRank(int64 rank) const { return AppendImpl(ShapePatternRankImpl(rank)); } // Modifies the pattern to match only if the shape has a layout that matches // the given pattern. template - auto WithLayout(const LayoutPattern& layout) const - -> decltype(this->AppendImpl( - ShapePatternLayoutImpl(layout))) { + auto WithLayout(const LayoutPattern& layout) const { return AppendImpl(ShapePatternLayoutImpl(layout)); } - constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const - -> decltype(this->WithLayout(Layout().EqualTo(layout))) { + constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const { return WithLayout(Layout().EqualTo(layout)); } - constexpr auto IsDenseArray() const - -> decltype(this->WithLayout(Layout().WithDenseFormat())) { + constexpr auto IsDenseArray() const { return WithLayout(Layout().WithDenseFormat()); } // Modifies the pattern to match only if the shape has a subshape that matches // the given pattern. template - auto WithSubshape(ShapeIndexView index, - const ShapePattern& subshape) - const -> decltype(this->AppendImpl( - ShapePatternSubshapeImpl(index, - subshape))) { + auto WithSubshape( + ShapeIndexView index, + const ShapePattern& subshape) const { return AppendImpl( ShapePatternSubshapeImpl(index, subshape)); } @@ -1101,17 +1071,13 @@ class ShapePattern { } // namespace detail // Creates a shape pattern that will capture the matched layout in the argument. -inline constexpr detail::ShapePattern -Shape(const ::xla::Shape** matched_shape = nullptr) { +inline constexpr auto Shape(const ::xla::Shape** matched_shape = nullptr) { return detail::ShapePattern( detail::ShapePatternBaseImpl(), matched_shape); } // Creates a shape pattern that will capture the matched layout in the argument. 
-inline constexpr detail::ShapePattern<::xla::Shape, - detail::ShapePatternBaseImpl> -Shape(::xla::Shape** matched_shape) { +inline constexpr auto Shape(::xla::Shape** matched_shape) { return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>( detail::ShapePatternBaseImpl(), matched_shape); } @@ -1797,9 +1763,7 @@ template class HloInstructionPattern { private: template - auto AppendImpl(NewImpl new_impl) const -> HloInstructionPattern< - HloInstructionType, decltype(AllOf<::xla::HloInstruction>( - std::declval(), std::move(new_impl)))> { + auto AppendImpl(NewImpl new_impl) const { auto new_allof = AllOf<::xla::HloInstruction>(impl_, std::move(new_impl)); return HloInstructionPattern( std::move(new_allof), matched_inst_); @@ -1837,51 +1801,38 @@ class HloInstructionPattern { } // Modifies the pattern to match only if the instruction has the given name. - auto WithName(absl::string_view name) const - -> decltype(this->AppendImpl(HloInstructionPatternNameImpl(name))) { + auto WithName(absl::string_view name) const { return AppendImpl(HloInstructionPatternNameImpl(name)); } // Modifies the pattern to match only if the instruction has the given opcode. - auto WithOpcode(HloOpcode opcode) const - -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode, - false))) { + auto WithOpcode(HloOpcode opcode) const { return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false)); } // Modifies the pattern to match only the custom call with a given target. - auto WithCustomCallTarget(absl::string_view custom_call_target) const - -> decltype(this->AppendImpl( - HloInstructionCustomCallTargetImpl(custom_call_target))) { + auto WithCustomCallTarget(absl::string_view custom_call_target) const { return AppendImpl(HloInstructionCustomCallTargetImpl(custom_call_target)); } - auto WithNumOperands(int64 num_operands) const -> decltype( - this->AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands))) { + auto WithNumOperands(int64 num_operands) const { return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands)); } // Modifies the pattern to match only if the instruction does not have the // given opcode. - auto WithoutOpcode(HloOpcode opcode) const - -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode, - true))) { + auto WithoutOpcode(HloOpcode opcode) const { return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true)); } - constexpr auto Is(const HloInstruction* instr) const - -> decltype(this->AppendImpl(HloInstructionIsImpl(instr))) { + constexpr auto Is(const HloInstruction* instr) const { return AppendImpl(HloInstructionIsImpl(instr)); } // Modifies the pattern to match only if the instruction is a constant. - constexpr auto IsConstant() const - -> decltype(this->WithOpcode(HloOpcode::kConstant)) { - return WithOpcode(HloOpcode::kConstant); - } + constexpr auto IsConstant() const { return WithOpcode(HloOpcode::kConstant); } - constexpr auto IsConstantScalar() const -> decltype(this->AppendImpl( - HloConstantScalarImpl(/*match_effective_scalar=*/false))) { + constexpr auto IsConstantScalar() const { return AppendImpl( HloConstantScalarImpl(/*match_effective_scalar=*/false)); } @@ -1889,39 +1840,32 @@ class HloInstructionPattern { // This does not check that T has the same type as the instruction, so e.g. // IsConstantScalar(1.0) may match a constant of shape int32[]. 
template - constexpr auto IsConstantScalar(const ScalarTy& val) const - -> decltype(this->AppendImpl(HloConstantScalarImpl( - val, /*match_effective_scalar=*/false))) { + constexpr auto IsConstantScalar(const ScalarTy& val) const { return AppendImpl( HloConstantScalarImpl(val, /*match_effective_scalar=*/false)); } - constexpr auto IsConstantEffectiveScalar() const -> decltype(this->AppendImpl( - HloConstantScalarImpl(/*match_effective_scalar=*/true))) { + constexpr auto IsConstantEffectiveScalar() const { return AppendImpl( HloConstantScalarImpl(/*match_effective_scalar=*/true)); } template - constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const - -> decltype(this->AppendImpl(HloConstantScalarImpl( - val, /*match_effective_scalar=*/true))) { + constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const { return AppendImpl( HloConstantScalarImpl(val, /*match_effective_scalar=*/true)); } // Modifies the pattern to match only if the instruction is not a constant. - constexpr auto IsNonConstant() const - -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) { + constexpr auto IsNonConstant() const { return WithoutOpcode(HloOpcode::kConstant); } // Modifies the pattern to match only if the instruction has a shape that // matches the given pattern. template - constexpr auto WithShape(const ShapePattern& shape) - const -> decltype(this->AppendImpl( - HloInstructionPatternShapeImpl(shape))) { + constexpr auto WithShape( + const ShapePattern& shape) const { return AppendImpl( HloInstructionPatternShapeImpl(shape)); } @@ -1929,16 +1873,14 @@ class HloInstructionPattern { // Make this a templated function to work around gcc 4.9.4 template infinite // recursion bug. template - constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const - -> decltype(this->WithShape(Shape().EqualTo(shape))) { + constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const { return WithShape(Shape().EqualTo(shape)); } // Make this a templated function to work around gcc 4.9.4 template infinite // recursion bug. template - constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const - -> decltype(this->WithShape(Shape().CompatibleTo(shape))) { + constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const { return WithShape(Shape().CompatibleTo(shape)); } @@ -1947,10 +1889,7 @@ class HloInstructionPattern { template constexpr auto WithOperand( int64 operand_index, - const HloInstructionPattern& operand) const - -> decltype(this->AppendImpl( - HloInstructionPatternOperandImpl( - operand_index, operand))) { + const HloInstructionPattern& operand) const { return AppendImpl( HloInstructionPatternOperandImpl( operand_index, operand)); @@ -1960,11 +1899,7 @@ class HloInstructionPattern { typename OperandImpl2> constexpr auto WithBinaryOperandsAnyOrder( const HloInstructionPattern& op1, - const HloInstructionPattern& op2) const - -> decltype(this->AppendImpl( - HloInstructionPatternBinaryOperandsAnyOrderImpl< - OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, - op2))) { + const HloInstructionPattern& op2) const { return AppendImpl( HloInstructionPatternBinaryOperandsAnyOrderImpl< OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2)); @@ -1972,46 +1907,39 @@ class HloInstructionPattern { // Modifies the pattern to match only if the instruction is a fusion node with // the given kind. 
- constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const - -> decltype(this->AppendImpl(HloInstructionPatternFusionKindImpl(kind))) { + constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const { return AppendImpl(HloInstructionPatternFusionKindImpl(kind)); } // Modifies the pattern to match only if the instruction is a // get-tuple-element with the given tuple index. - constexpr auto WithTupleIndex(int64 tuple_index) const -> decltype( - this->AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index))) { + constexpr auto WithTupleIndex(int64 tuple_index) const { return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index)); } // Modifies the pattern to match only if the instruction is a parameter // with the given parameter number. - constexpr auto WithParameterNum(int64 parameter_num) const -> decltype( - this->AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num))) { + constexpr auto WithParameterNum(int64 parameter_num) const { return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num)); } // Modifies the pattern to match if the instruction is used exactly once. // Does not match if the instruction is used twice by the same user (e.g. // multiply(x,x)). - constexpr auto WithOneUse() const - -> decltype(this->AppendImpl(HloInstructionPatternOneUseImpl())) { + constexpr auto WithOneUse() const { return AppendImpl(HloInstructionPatternOneUseImpl()); } // Modifies the pattern to match if the instruction is used by exactly one // other instruction. Will match if the instruction is used twice, so long as // it's by the same user (e.g. multiply(x,x)). - constexpr auto WithOneUser() const - -> decltype(this->AppendImpl(HloInstructionPatternOneUserImpl())) { + constexpr auto WithOneUser() const { return AppendImpl(HloInstructionPatternOneUserImpl()); } // Modifies the pattern to match only if the instruction has the given // comparison direction. - auto WithComparisonDirection(ComparisonDirection direction) const - -> decltype(this->AppendImpl( - HloInstructionPatternComparisonDirectionImpl(direction))) { + auto WithComparisonDirection(ComparisonDirection direction) const { return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction)); } @@ -2028,9 +1956,7 @@ class HloInstructionPattern { // Creates an instruction pattern that will capture the matched instruction in // the argument. -inline constexpr detail::HloInstructionPattern< - const ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl> -Op(const ::xla::HloInstruction** matched_inst = nullptr) { +inline constexpr auto Op(const ::xla::HloInstruction** matched_inst = nullptr) { return detail::HloInstructionPattern( detail::HloInstructionPatternBaseImpl(), matched_inst); @@ -2038,24 +1964,19 @@ Op(const ::xla::HloInstruction** matched_inst = nullptr) { // Creates an instruction pattern that will capture the matched instruction in // the argument. -inline constexpr detail::HloInstructionPattern< - ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl> -Op(::xla::HloInstruction** matched_inst) { +inline constexpr auto Op(::xla::HloInstruction** matched_inst) { return detail::HloInstructionPattern<::xla::HloInstruction, detail::HloInstructionPatternBaseImpl>( detail::HloInstructionPatternBaseImpl(), matched_inst); } // Helpers for nullary instructions. 
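The helper macros below, like the member functions in the hunks above, all apply the same mechanical simplification: the C++11-style `-> decltype(...)` trailing return types are dropped and the return type is left to C++14 return type deduction. A standalone illustration of the equivalence (not taken from the patch; the names are made up):

template <typename T>
auto TwiceOld(const T& x) -> decltype(x + x) {  // C++11: restate the type.
  return x + x;
}

template <typename T>
auto TwiceNew(const T& x) {  // C++14: deduced from the return statement.
  return x + x;
}

One difference between the two forms in general is that the decltype spelling participates in SFINAE while plain deduction turns a failure into a hard error; that presumably does not matter for these matcher factories, which are not used as overload candidates.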
-#define XLA_NULLOP_PATTERN(NAME) \ - inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ - return Op().WithOpcode(HloOpcode::k##NAME); \ - } \ - \ - template \ - inline auto NAME(HloInstructionType** matched_inst) \ - ->decltype(Op(matched_inst).WithOpcode(HloOpcode::k##NAME)) { \ - return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \ +#define XLA_NULLOP_PATTERN(NAME) \ + inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst) { \ + return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \ } XLA_NULLOP_PATTERN(Constant) XLA_NULLOP_PATTERN(Parameter) @@ -2064,28 +1985,21 @@ XLA_NULLOP_PATTERN(Rng) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. -#define XLA_UNOP_PATTERN(NAME) \ - inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ - return Op().WithOpcode(HloOpcode::k##NAME); \ - } \ - \ - template \ - inline auto NAME(Arg&& arg)->decltype( \ - Op().WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(arg))) { \ - return Op() \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(arg)); \ - } \ - \ - template \ - inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) \ - ->decltype(Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(arg))) { \ - return Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(arg)); \ +#define XLA_UNOP_PATTERN(NAME) \ + inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \ + \ + template \ + inline auto NAME(Arg&& arg) { \ + return Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg)); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg)); \ } XLA_UNOP_PATTERN(Abs) XLA_UNOP_PATTERN(RoundNearestAfz) @@ -2124,55 +2038,40 @@ XLA_UNOP_PATTERN(Transpose) #undef XLA_UNOP_PATTERN // Helpers for binary instructions. 
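For orientation, the factories generated by these macros (and the hand-written ones above) are normally composed and consumed through `Match()`. A minimal sketch, where `instr` stands for some `HloInstruction*` under inspection and is not defined in this header:

namespace m = match;

HloInstruction* x = nullptr;
if (Match(instr, m::AddAnyOrder(m::Op(&x), m::ConstantScalar()))) {
  // `instr` is an add with a scalar-constant operand on either side; `x` is
  // bound to whichever operand m::Op(&x) matched.
}

`AddAnyOrder` comes from XLA_COMMUTATIVE_BINOP_PATTERN(Add) below; this change only affects how the return types of these helpers are spelled, not how callers use them.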
-#define XLA_BINOP_PATTERN(NAME) \ - inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ - return Op().WithOpcode(HloOpcode::k##NAME); \ - } \ - \ - template \ - inline auto NAME(Lhs&& lhs, Rhs&& rhs) \ - ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs))) { \ - return Op() \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs)); \ - } \ - \ - template \ - inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \ - ->decltype(Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs))) { \ - return Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs)); \ +#define XLA_BINOP_PATTERN(NAME) \ + inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \ + \ + template \ + inline auto NAME(Lhs&& lhs, Rhs&& rhs) { \ + return Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)); \ } -#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \ - XLA_BINOP_PATTERN(NAME) \ - \ - template \ - inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ - Rhs&& rhs) \ - ->decltype(Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithBinaryOperandsAnyOrder(std::forward(lhs), \ - std::forward(rhs))) { \ - return Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithBinaryOperandsAnyOrder(std::forward(lhs), \ - std::forward(rhs)); \ - } \ - template \ - inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ - ->decltype(NAME##AnyOrder( \ - nullptr, std::forward(lhs), std::forward(rhs))) { \ - return NAME##AnyOrder( \ - nullptr, std::forward(lhs), std::forward(rhs)); \ +#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \ + XLA_BINOP_PATTERN(NAME) \ + \ + template \ + inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ + Rhs&& rhs) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs)); \ + } \ + template \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) { \ + return NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs)); \ } XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) @@ -2202,16 +2101,10 @@ XLA_BINOP_PATTERN(ShiftRightLogical) // Helpers for ternary instructions. 
#define XLA_TERNOP_PATTERN(NAME) \ - inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ - return Op().WithOpcode(HloOpcode::k##NAME); \ - } \ + inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \ \ template \ - inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) \ - ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(arg0)) \ - .WithOperand(1, std::forward(arg1)) \ - .WithOperand(2, std::forward(arg2))) { \ + inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) { \ return Op() \ .WithOpcode(HloOpcode::k##NAME) \ .WithOperand(0, std::forward(arg0)) \ @@ -2222,12 +2115,7 @@ XLA_BINOP_PATTERN(ShiftRightLogical) template \ inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0, \ - Arg1&& arg1, Arg2&& arg2) \ - ->decltype(Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(arg0)) \ - .WithOperand(1, std::forward(arg1)) \ - .WithOperand(2, std::forward(arg2))) { \ + Arg1&& arg1, Arg2&& arg2) { \ return Op(matched_inst) \ .WithOpcode(HloOpcode::k##NAME) \ .WithOperand(0, std::forward(arg0)) \ @@ -2241,17 +2129,13 @@ XLA_TERNOP_PATTERN(Select); namespace detail { template -inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg) - -> decltype(m.WithOperand(operand_num, std::forward(first_arg))) { +inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg) { return m.WithOperand(operand_num, std::forward(first_arg)); } template inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg, - Args&&... args) - -> decltype(WithOperands(m.WithOperand(operand_num, - std::forward(first_arg)), - operand_num + 1, std::forward(args)...)) { + Args&&... args) { return WithOperands( m.WithOperand(operand_num, std::forward(first_arg)), operand_num + 1, std::forward(args)...); @@ -2259,26 +2143,17 @@ inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg, } // namespace detail #define XLA_VARIADIC_OP_PATTERN(NAME) \ - inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ - return Op().WithOpcode(HloOpcode::k##NAME); \ - } \ + inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \ \ template \ - inline auto NAME(Args&&... args) \ - ->decltype(detail::WithOperands(Op().WithOpcode(HloOpcode::k##NAME) \ - .WithNumOperands(sizeof...(Args)), \ - 0, std::forward(args)...)) { \ + inline auto NAME(Args&&... args) { \ return detail::WithOperands( \ Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \ /*operand_num=*/0, std::forward(args)...); \ } \ \ template \ - inline auto NAME(HloInstructionType** matched_inst, Args&&... args) \ - ->decltype(detail::WithOperands(Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithNumOperands(sizeof...(Args)), \ - 0, std::forward(args)...)) { \ + inline auto NAME(HloInstructionType** matched_inst, Args&&... args) { \ return detail::WithOperands(Op(matched_inst) \ .WithOpcode(HloOpcode::k##NAME) \ .WithNumOperands(sizeof...(Args)), \ @@ -2299,63 +2174,46 @@ XLA_VARIADIC_OP_PATTERN(Sort); XLA_VARIADIC_OP_PATTERN(Tuple); // Helpers for comparison instructions. 
-#define XLA_COMPARE_PATTERN(NAME) \ - inline auto NAME()->decltype( \ - Op().WithOpcode(HloOpcode::kCompare) \ - .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ - return Op() \ - .WithOpcode(HloOpcode::kCompare) \ - .WithComparisonDirection(ComparisonDirection::k##NAME); \ - } \ - \ - template \ - inline auto NAME(Lhs&& lhs, Rhs&& rhs) \ - ->decltype(Op().WithOpcode(HloOpcode::kCompare) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs)) \ - .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ - return Op() \ - .WithOpcode(HloOpcode::kCompare) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs)) \ - .WithComparisonDirection(ComparisonDirection::k##NAME); \ - } \ - \ - template \ - inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \ - ->decltype(Op(matched_inst) \ - .WithOpcode(HloOpcode::kCompare) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs)) \ - .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ - return Op(matched_inst) \ - .WithOpcode(HloOpcode::kCompare) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs)) \ - .WithComparisonDirection(ComparisonDirection::k##NAME); \ +#define XLA_COMPARE_PATTERN(NAME) \ + inline auto NAME() { \ + return Op() \ + .WithOpcode(HloOpcode::kCompare) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Lhs&& lhs, Rhs&& rhs) { \ + return Op() \ + .WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ } -#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \ - XLA_COMPARE_PATTERN(NAME) \ - \ - template \ - inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ - Rhs&& rhs) \ - ->decltype(Op(matched_inst) \ - .WithOpcode(HloOpcode::kCompare) \ - .WithBinaryOperandsAnyOrder(std::forward(lhs), \ - std::forward(rhs))) { \ - return Op(matched_inst) \ - .WithOpcode(HloOpcode::kCompare) \ - .WithBinaryOperandsAnyOrder(std::forward(lhs), \ - std::forward(rhs)); \ - } \ - template \ - inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ - ->decltype(NAME##AnyOrder( \ - nullptr, std::forward(lhs), std::forward(rhs))) { \ - return NAME##AnyOrder( \ - nullptr, std::forward(lhs), std::forward(rhs)); \ +#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \ + XLA_COMPARE_PATTERN(NAME) \ + \ + template \ + inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ + Rhs&& rhs) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs)); \ + } \ + template \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) { \ + return NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs)); \ } XLA_COMMUTATIVE_COMPARE_PATTERN(Eq); @@ -2366,23 +2224,17 @@ XLA_COMPARE_PATTERN(Le); XLA_COMPARE_PATTERN(Lt); // Helpers for matching non-constant instructions. 
-inline auto NonConstant() -> decltype(Op().IsNonConstant()) { - return Op().IsNonConstant(); -} +inline auto NonConstant() { return Op().IsNonConstant(); } template -inline auto NonConstant(HloInstructionType** matched_inst) - -> decltype(Op(matched_inst).IsNonConstant()) { +inline auto NonConstant(HloInstructionType** matched_inst) { return Op(matched_inst).IsNonConstant(); } // Add overloads for GetTupleElement which take a int64 specifying which tuple // element is selected. template -inline auto GetTupleElement(Arg&& arg, int64 tuple_index) - -> decltype(Op().WithOpcode(HloOpcode::kGetTupleElement) - .WithOperand(0, std::forward(arg)) - .WithTupleIndex(tuple_index)) { +inline auto GetTupleElement(Arg&& arg, int64 tuple_index) { return Op() .WithOpcode(HloOpcode::kGetTupleElement) .WithOperand(0, std::forward(arg)) @@ -2391,11 +2243,7 @@ inline auto GetTupleElement(Arg&& arg, int64 tuple_index) template inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, - int64 tuple_index) - -> decltype(Op(matched_inst) - .WithOpcode(HloOpcode::kGetTupleElement) - .WithOperand(0, std::forward(arg)) - .WithTupleIndex(tuple_index)) { + int64 tuple_index) { return Op(matched_inst) .WithOpcode(HloOpcode::kGetTupleElement) .WithOperand(0, std::forward(arg)) @@ -2404,62 +2252,50 @@ inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, // Add overloads for Parameter which take an int64 specifying the parameter // number. -inline auto Parameter(int64 parameter_num) -> decltype( - Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num)) { +inline auto Parameter(int64 parameter_num) { return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num); } template -inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num) - -> decltype(Op(matched_inst) - .WithOpcode(HloOpcode::kParameter) - .WithParameterNum(parameter_num)) { +inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num) { return Op(matched_inst) .WithOpcode(HloOpcode::kParameter) .WithParameterNum(parameter_num); } -inline auto ConstantScalar() -> decltype(Op().IsConstantScalar()) { - return Op().IsConstantScalar(); -} +inline auto ConstantScalar() { return Op().IsConstantScalar(); } template -inline auto ConstantScalar(HloInstructionType** matched_inst) - -> decltype(Op(matched_inst).IsConstantScalar()) { +inline auto ConstantScalar(HloInstructionType** matched_inst) { return Op(matched_inst).IsConstantScalar(); } template -inline auto ConstantScalar(ScalarTy val) - -> decltype(Op().IsConstantScalar(val)) { +inline auto ConstantScalar(ScalarTy val) { return Op().IsConstantScalar(val); } template -inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) - -> decltype(Op(matched_inst).IsConstantScalar(val)) { +inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) { return Op(matched_inst).IsConstantScalar(val); } -inline auto ConstantEffectiveScalar() -> decltype(Op().IsConstantScalar()) { +inline auto ConstantEffectiveScalar() { return Op().IsConstantEffectiveScalar(); } template -inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) - -> decltype(Op(matched_inst).IsConstantScalar()) { +inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) { return Op(matched_inst).IsConstantEffectiveScalar(); } template -inline auto ConstantEffectiveScalar(ScalarTy val) - -> decltype(Op().IsConstantEffectiveScalar(val)) { +inline auto ConstantEffectiveScalar(ScalarTy val) { return 
Op().IsConstantEffectiveScalar(val); } template inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst, - ScalarTy val) - -> decltype(Op(matched_inst).IsConstantEffectiveScalar(val)) { + ScalarTy val) { return Op(matched_inst).IsConstantEffectiveScalar(val); } diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index e3a3feb8640..bd99f920ea0 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -325,6 +325,22 @@ static StatusOr> ScatterLoopBody( {updated_operand, scatter_indices, updates}}; } +static int64 ScatterTripCount(HloInstruction* scatter) { + // Compute the trip count for the while loop to be used for scatter. This + // should be the number of indices we should scatter into the operand. + HloInstruction* scatter_indices = scatter->mutable_operand(1); + const Shape& scatter_indices_shape = scatter_indices->shape(); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + int64 scatter_loop_trip_count = 1; + for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + scatter_loop_trip_count *= scatter_indices_shape.dimensions(i); + } + } + return scatter_loop_trip_count; +} + // High Level Algorithm. // // 1. Canonicalize the scatter_indices tensor such that it has rank 2, where @@ -342,7 +358,7 @@ static StatusOr> ScatterLoopBody( // from c. and d. using the update_computation of scatter. // f. Write the updated value of the slice into the operand tensor. -StatusOr ScatterExpander::ExpandScatter( +StatusOr ScatterExpander::ExpandInstruction( HloInstruction* scatter) { HloInstruction* operand = scatter->mutable_operand(0); HloInstruction* scatter_indices = scatter->mutable_operand(1); @@ -358,13 +374,7 @@ StatusOr ScatterExpander::ExpandScatter( // Compute the trip count for the while loop to be used for scatter. This // should be the number of indices we should scatter into the operand. 
- const Shape& scatter_indices_shape = scatter_indices->shape(); - int64 scatter_loop_trip_count = 1; - for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { - if (i != dim_numbers.index_vector_dim()) { - scatter_loop_trip_count *= scatter_indices_shape.dimensions(i); - } - } + int64 scatter_loop_trip_count = ScatterTripCount(scatter); if (!IsInt32(scatter_loop_trip_count)) { return Unimplemented( "Scatter operations with more than 2147483647 scatter indices are not " @@ -408,23 +418,9 @@ StatusOr ScatterExpander::ExpandScatter( return scatter_loop_result.front(); } -StatusOr ScatterExpander::Run(HloModule* module) { - std::vector scatter_instrs; - for (HloComputation* computation : module->MakeNonfusionComputations()) { - for (HloInstruction* instr : computation->instructions()) { - if (instr->opcode() == HloOpcode::kScatter) { - scatter_instrs.push_back(instr); - } - } - } - - for (auto instr : scatter_instrs) { - TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr)); - TF_RETURN_IF_ERROR( - instr->parent()->ReplaceInstruction(instr, expanded_root)); - } - - return !scatter_instrs.empty(); +bool ScatterExpander::InstructionMatchesPattern(HloInstruction* inst) { + return inst->opcode() == HloOpcode::kScatter && + (mode_ == kEliminateAllScatters || ScatterTripCount(inst) == 1); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h index 533af060bc9..aa59e7ec3b0 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.h +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -16,17 +16,43 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" namespace xla { -class ScatterExpander : public HloModulePass { +// This pass rewrites scatter operations into (roughly) while loops of +// dynamic-update-slices. +// +// This pass can be used in two ways: +// +// - kEliminateAllScatters: For backends that don't support scatter, this pass +// can convert every scatter into a loop. +// +// - kEliminateSimpleScatters: For backends that *do* support scatter, this +// pass can strength-reduce "simple" scatters -- specifically, scatters that +// can be represented without a loop -- to dynamic-update-slices. +// +// Note that even in kEliminateSimpleScatters mode, this pass may still expand a +// scatter into a loop (with a trip-count of 1). It's up to other +// simplification passes to remove the loop. 
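Given the two modes, a hedged sketch of how a backend might register the pass; the pipeline name and placement are illustrative, not taken from this change:

HloPassPipeline pipeline("scatter-expansion");
// Backend without native scatter support: rewrite every scatter into a loop.
pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
// Backend with a scatter emitter: only strength-reduce trivial scatters
// (trip count 1), as selected by InstructionMatchesPattern() below.
// pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateSimpleScatters);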
+class ScatterExpander : public OpExpanderPass { public: + enum Mode { + kEliminateAllScatters, + kEliminateSimpleScatters, + }; + + explicit ScatterExpander(Mode m) : mode_(m) {} + absl::string_view name() const override { return "scatter_expander"; } - StatusOr Run(HloModule* module) override; protected: - StatusOr ExpandScatter(HloInstruction* scatter); + bool InstructionMatchesPattern(HloInstruction* inst) override; + + StatusOr ExpandInstruction(HloInstruction* scatter) override; + + private: + Mode mode_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/scatter_expander_test.cc b/tensorflow/compiler/xla/service/scatter_expander_test.cc index 3852b82c1ef..9f4cc5406d8 100644 --- a/tensorflow/compiler/xla/service/scatter_expander_test.cc +++ b/tensorflow/compiler/xla/service/scatter_expander_test.cc @@ -57,11 +57,79 @@ TEST_F(ScatterExpanderTest, ScatterOperandWithoutLayout) { ParseAndReturnVerifiedModule(kModuleStr)); // The HLO parser changes all no layout shapes from the input to have a - // default layout, clear the layout of the scatter operand for testing. + // default layout. Clear the layout of the scatter operand for testing. HloInstruction* scatter_operand = FindInstruction(module.get(), "operand"); scatter_operand->mutable_shape()->clear_layout(); - ScatterExpander scatter_expander; + ScatterExpander scatter_expander(ScatterExpander::kEliminateAllScatters); + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&scatter_expander, module.get())); + EXPECT_TRUE(result); +} + +TEST_F(ScatterExpanderTest, EliminateSimpleScattersSkipsNontrivialScatter) { + const char* kModuleStr = R"( + HloModule scatter_expander + + scatter_computation { + parameter0 = s32[] parameter(0) + ROOT parameter1 = s32[] parameter(1) + } + + ENTRY kernel_entry { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=scatter_computation, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + // The HLO parser changes all no layout shapes from the input to have a + // default layout. Clear the layout of the scatter operand for testing. + HloInstruction* scatter_operand = FindInstruction(module.get(), "operand"); + scatter_operand->mutable_shape()->clear_layout(); + + ScatterExpander scatter_expander(ScatterExpander::kEliminateSimpleScatters); + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&scatter_expander, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(ScatterExpanderTest, EliminateSimpleScattersRewritesTrivialScatter) { + const char* kModuleStr = R"( + HloModule scatter_expander + + scatter_computation { + parameter0 = s32[] parameter(0) + ROOT parameter1 = s32[] parameter(1) + } + + ENTRY kernel_entry { + operand = s32[5] iota(), iota_dimension=0 + indices = s32[1] parameter(0) + update = s32[] constant(0) + ROOT scatter = s32[5]{0} scatter(operand, indices, update), + update_window_dims={}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, index_vector_dim=0, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + // The HLO parser changes all no layout shapes from the input to have a + // default layout. Clear the layout of the scatter operand for testing. 
+ HloInstruction* scatter_operand = FindInstruction(module.get(), "operand"); + scatter_operand->mutable_shape()->clear_layout(); + + ScatterExpander scatter_expander(ScatterExpander::kEliminateSimpleScatters); TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&scatter_expander, module.get())); EXPECT_TRUE(result); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 8e39e32e4c3..a96c9c34260 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2825,6 +2825,38 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return output_shape; } +/* static */ StatusOr ShapeInference::InferDynamicReshapeShape( + const Shape& operand, absl::Span dim_size_shapes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic) { + if (new_size_bounds.size() != dims_are_dynamic.size()) { + return InvalidArgument( + "DynamicReshape has to have the same number of elements in new_sizes " + "(%d) and dims_are_dynamic (%d)", + new_size_bounds.size(), dims_are_dynamic.size()); + } + + for (const Shape* dim_size_shape : dim_size_shapes) { + if (dim_size_shape->element_type() != S32 && dim_size_shape->rank() != 0) { + return InvalidArgument( + "DynamicReshape's dim size has to be scalar S32, got (%s): ", + dim_size_shape->ToString()); + } + } + + Shape inferred_shape = ShapeUtil::MakeShape( + operand.element_type(), new_size_bounds, dims_are_dynamic); + if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { + return InvalidArgument( + "Reshape operation has mismatched element counts: from=%d (%s) " + "to=%d (%s).", + ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand), + ShapeUtil::ElementsIn(inferred_shape), + ShapeUtil::HumanString(inferred_shape)); + } + return inferred_shape; +} + /* static */ StatusOr ShapeInference::InferReshapeShape( const Shape& operand, absl::Span dimensions, absl::Span new_sizes, int64 inferred_dimension) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index d47d96ab52d..f03e4e5fa98 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -241,6 +241,15 @@ class ShapeInference { absl::Span new_sizes, int64 inferred_dimension); + // Infers the shape produced by a dynamic reshape operation from the element + // type of its operand and the new dimension sizes specified. The result shape + // will have dynamic dimensions as specific in `dim_is_dynamic` and bound + // `new_size_bounds`. + static StatusOr InferDynamicReshapeShape( + const Shape& operand, absl::Span dim_size_shapes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic); + // Infers the shape produced by a transpose operation from the element type of // its operand and its dimensions field. static StatusOr InferTransposeShape( diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc index 5d85fb5189c..6524973a08e 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -91,9 +91,7 @@ bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) { return is_better; } if (!rhs.IsTileMaximal()) { - // If we already have a non-tile-maximal sharding then we can't improve - // that. 
- return false; + return lhs.NumTiles() > rhs.NumTiles(); } else if (!rhs.IsReplicated()) { // If we are not replicated then only tiled (not tile maximal) shardings // can improve us. @@ -122,22 +120,158 @@ HloSharding MergeForMoreSpecificSharding(const HloSharding& a, return IsShardingMoreSpecific(a, b) ? a : b; } +// Tries to refine `to_merge` by combining with `old`. Returns if the final +// `to_merge` is more specific than `old`. May combine partial sharding in +// addition to MergeForMoreSpecificSharding(). +bool MergeSharding(const HloSharding& old, HloSharding* to_merge, + bool may_combine_partial_sharding) { + if (old.IsTuple()) { + CHECK(to_merge->IsTuple()); + bool changed = false; + for (int64 i = 0; i < old.tuple_elements().size(); ++i) { + changed |= + MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i], + may_combine_partial_sharding); + } + return changed; + } + if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() || + !to_merge->ReplicateOnLastTileDim() || + old.tile_assignment().num_elements() != + to_merge->tile_assignment().num_elements()) { + return IsShardingMoreSpecific(*to_merge, old); + } + // Combine the tile dimension sizes from new and old. + int64 num_devices = old.tile_assignment().num_elements(); + std::vector new_tile_dims; + bool compatible = true; + new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions()); + for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) { + int64 new_dim = to_merge->tile_assignment().dim(i); + int64 old_dim = old.tile_assignment().dim(i); + if (new_dim == 1) { + new_tile_dims.push_back(old_dim); + } else if (old_dim == 1) { + new_tile_dims.push_back(new_dim); + } else if (new_dim == old_dim) { + new_tile_dims.push_back(new_dim); + } else { + compatible = false; + break; + } + } + int64 replication = num_devices / Product(new_tile_dims); + if (!compatible || num_devices % Product(new_tile_dims) != 0 || + replication >= old.tile_assignment().dimensions().back()) { + return IsShardingMoreSpecific(*to_merge, old); + } + new_tile_dims.push_back(replication); + Array new_tile(new_tile_dims); + // Maps from replication group ID to sorted members. + absl::flat_hash_map> old_group_members; + absl::flat_hash_map> new_group_members; + auto get_group_index = [&](absl::Span tile_indices, + const HloSharding& sharding) { + int64 group_id = 0; + for (int64 i = 0; i < tile_indices.size() - 1; ++i) { + group_id *= to_merge->tile_assignment().dim(i); + group_id += tile_indices[i]; + } + return group_id; + }; + old.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + old_group_members[get_group_index(indices, old)].insert(device); + }); + to_merge->tile_assignment().Each( + [&](absl::Span indices, int64 device) { + new_group_members[get_group_index(indices, *to_merge)].insert(device); + }); + // Try to find the intersection of old and new replication groups, in + // order to determine the merged tile assignment. 
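For concreteness, a hand-worked example of that intersection (not part of the patch): over four devices, merging an existing sharding {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} (dimension 0 split between groups {0,1} and {2,3}, each group replicated) with a new sharding {devices=[1,2,2]0,2,1,3 last_tile_dim_replicate} (dimension 1 split between groups {0,2} and {1,3}) is compatible: at every step the smallest remaining member of the corresponding old and new groups coincides, so the loop below fills the merged tile as {devices=[2,2]0,1,2,3}, and since the leftover replication factor is 1 the result is emitted as a plain tiled sharding rather than a partial one.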
+ new_tile.Each([&](absl::Span indices, int64* device) { + if (!compatible) { + return; + } + std::vector old_index(indices.begin(), indices.end()); + std::vector new_index = old_index; + for (int64 i = 0; i < indices.size() - 1; ++i) { + if (old.tile_assignment().dim(i) == 1) { + old_index[i] = 0; + } + if (to_merge->tile_assignment().dim(i) == 1) { + new_index[i] = 0; + } + } + int64 old_group_id = get_group_index(old_index, old); + int64 new_group_id = get_group_index(new_index, *to_merge); + if (old_group_members[old_group_id].empty() || + new_group_members[new_group_id].empty() || + *old_group_members[old_group_id].begin() != + *new_group_members[new_group_id].begin()) { + compatible = false; + return; + } + *device = *old_group_members[old_group_id].begin(); + old_group_members[old_group_id].erase(*device); + new_group_members[new_group_id].erase(*device); + }); + if (compatible) { + if (replication == 1) { + new_tile_dims.pop_back(); + new_tile.Reshape(new_tile_dims); + *to_merge = HloSharding::Tile(new_tile); + } else { + *to_merge = HloSharding::PartialTile(new_tile); + } + return true; + } + return IsShardingMoreSpecific(*to_merge, old); +} + // Updates the sharding of the specified instruction with the specified sharding // if it is better than the current one and returns true if a new sharding have -// been applied. -bool MaybeImproveInstructionSharding(const HloSharding& sharding, - HloInstruction* instruction) { +// been applied. If may_combine_partial_sharding is true, this may combine the +// new and existing sharding if they are both partial tiling partial +// replication. +bool MaybeImproveInstructionSharding(HloSharding sharding, + HloInstruction* instruction, + bool may_combine_partial_sharding) { // We don't want to propagate tile maximal shardings. if (!IsSpatiallyPartitioned(sharding)) { return false; } // Any sharding is better then no sharding. if (!instruction->has_sharding()) { - instruction->set_sharding(sharding); + instruction->set_sharding(std::move(sharding)); return true; } - if (IsShardingMoreSpecific(sharding, instruction->sharding())) { - instruction->set_sharding(sharding); + int64 sharding_tiles = sharding.NumTiles(); + if (MergeSharding(instruction->sharding(), &sharding, + may_combine_partial_sharding)) { + // Override existing tiled sharding only when the new sharding is compatible + // with the existing one. This avoids unexpected resharding when `sharding` + // just has more tiles than existing sharding but they are not mergeable. 
+ if (instruction->shape().IsArray() && + !instruction->sharding().IsTileMaximal() && + sharding.NumTiles() == sharding_tiles) { + std::vector diff_dims; + for (int64 i = 0; i < instruction->shape().rank(); ++i) { + if (instruction->sharding().tile_assignment().dim(i) == + sharding.tile_assignment().dim(i)) { + continue; + } + if (instruction->sharding().tile_assignment().dim(i) != 1) { + return false; + } + diff_dims.push_back(i); + } + if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + sharding, diff_dims) != instruction->sharding()) { + return false; + } + } + instruction->set_sharding(std::move(sharding)); return true; } return false; @@ -277,6 +411,7 @@ const HloInstruction* PickRepresentativeOperand( case HloOpcode::kDot: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kDynamicReshape: case HloOpcode::kFft: case HloOpcode::kFusion: case HloOpcode::kGather: @@ -361,12 +496,114 @@ bool SupportSpatialPartitioning(const HloInstruction* instruction, } } +bool InferDotShardingFromOperands( + HloInstruction* instruction, + const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, + bool may_combine_partial_sharding) { + auto from_operand = [&](int64 operand_index) { + auto operand = instruction->operand(operand_index); + const HloSharding& operand_sharding = operand->sharding(); + if (operand_sharding.IsTileMaximal()) { + return operand_sharding; + } + std::vector contracting_dims; + contracting_dims.reserve(dnums.contracting_dims.size()); + for (const auto& dim : dnums.contracting_dims) { + contracting_dims.push_back(operand_index == 0 ? dim.lhs : dim.rhs); + } + // It's possible that some size-1 spatial dims of convolutions are parsed as + // non-contracting dims. We might have tiled dimensions on them. + for (const auto& dim : operand_index == 0 + ? dnums.rhs_non_contracting_dims + : dnums.lhs_non_contracting_dims) { + int64 d = operand_index == 0 ? dim.lhs : dim.rhs; + if (d > 0) { + contracting_dims.push_back(d); + } + } + auto replicate_contracting_dims = + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + operand_sharding, contracting_dims); + std::vector out_dims_to_op_perm(instruction->shape().rank(), -1); + std::vector op_dims_to_output_perm(operand->shape().rank(), -1); + for (const auto& dim : dnums.batch_dims) { + out_dims_to_op_perm[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs; + op_dims_to_output_perm[operand_index == 0 ? dim.lhs : dim.rhs] = + dim.output; + } + for (const auto& dim : operand_index == 0 + ? dnums.lhs_non_contracting_dims + : dnums.rhs_non_contracting_dims) { + out_dims_to_op_perm[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs; + op_dims_to_output_perm[operand_index == 0 ? dim.lhs : dim.rhs] = + dim.output; + } + return *hlo_sharding_util::TransposeShardingWithCollapsedDims( + replicate_contracting_dims, op_dims_to_output_perm, + out_dims_to_op_perm); + }; + bool changed = false; + int64 larger_operand = + ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) >= + ShapeUtil::ByteSizeOf(instruction->operand(1)->shape()) + ? 
0 + : 1; + if (IsSpatiallyPartitioned(instruction->operand(larger_operand))) { + changed |= MaybeImproveInstructionSharding(from_operand(larger_operand), + instruction, + may_combine_partial_sharding); + } + if (IsSpatiallyPartitioned(instruction->operand(1 - larger_operand))) { + changed |= MaybeImproveInstructionSharding(from_operand(1 - larger_operand), + instruction, + may_combine_partial_sharding); + } + return changed; +} + // Convolution handling for InferShardingFromOperands(). bool InferConvolutionShardingFromOperands(HloInstruction* instruction, - bool aggressive_prop) { + int64 aggressiveness, + bool may_combine_partial_sharding) { + auto get_partitions_for_dims = + [&](const HloInstruction* inst, + absl::Span< + const dot_as_convolution_util::DotConvolutionDimsInfo::DimNums> + dims, + int lhs_or_rhs) { + int64 partitions = 1; + if (!inst->has_sharding()) { + return partitions; + } + const auto& sharding = inst->sharding(); + if (sharding.IsTileMaximal()) { + return partitions; + } + for (const auto& dim : dims) { + if (lhs_or_rhs == 0) { + partitions *= sharding.tile_assignment().dim(dim.lhs); + } else { + CHECK_EQ(lhs_or_rhs, 1); + partitions *= sharding.tile_assignment().dim(dim.rhs); + } + } + return partitions; + }; + auto dot_dims = + dot_as_convolution_util::ParseConvolutionDimsInfo(instruction); + const int64 lhs_conv_spatial_partitions = get_partitions_for_dims( + instruction->operand(0), dot_dims.conv_spatial_dims, 0); + const int64 rhs_conv_spatial_partitions = get_partitions_for_dims( + instruction->operand(1), dot_dims.conv_spatial_dims, 1); + if (dot_dims.conv_spatial_dims.empty() || + (lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 && + instruction->batch_group_count() == 1 && + instruction->feature_group_count() == 1)) { + return InferDotShardingFromOperands(instruction, dot_dims, + may_combine_partial_sharding); + } const auto& dnums = instruction->convolution_dimension_numbers(); const HloInstruction* lhs = instruction->operand(0); - const HloInstruction* rhs = instruction->operand(1); auto get_tiled_sharding_based_on_lhs = [&] { CHECK(!lhs->sharding().IsTileMaximal()); std::vector output_to_lhs_indices(instruction->shape().rank()); @@ -381,103 +618,12 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction, return hlo_sharding_util::TransposeSharding(lhs->sharding(), output_to_lhs_indices); }; - auto get_tiled_sharding_based_on_rhs = [&] { - CHECK(!rhs->sharding().IsTileMaximal()); - std::vector output_to_rhs_indices(instruction->shape().rank()); - output_to_rhs_indices[dnums.output_batch_dimension()] = - dnums.kernel_input_feature_dimension(); - output_to_rhs_indices[dnums.output_feature_dimension()] = - dnums.kernel_output_feature_dimension(); - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - output_to_rhs_indices[dnums.output_spatial_dimensions(i)] = - dnums.kernel_spatial_dimensions(i); - } - return hlo_sharding_util::TransposeSharding(rhs->sharding(), - output_to_rhs_indices); - }; - if (auto dot_dims = dot_as_convolution_util::ParseDotGeneralFromConvolution( - instruction)) { - // lhs_or_rhs: lhs is 0 and rhs is 1. Skips dimensions with size 1. 
- auto partitioned_only_along_non_trivial_dims = - [&](const HloSharding& sharding, - std::vector& dims, - int64 lhs_or_rhs) { - if (sharding.IsTileMaximal()) { - return false; - } - int64 partition_count = 1; - for (const auto& dim : dims) { - if (lhs_or_rhs == 0) { - if (lhs->shape().dimensions(dim.lhs) == 1) { - continue; - } - partition_count *= sharding.tile_assignment().dim(dim.lhs); - } else { - if (rhs->shape().dimensions(dim.rhs) == 1) { - continue; - } - CHECK_EQ(lhs_or_rhs, 1); - partition_count *= sharding.tile_assignment().dim(dim.rhs); - } - } - return partition_count == sharding.tile_assignment().num_elements(); - }; - // If LHS/RHS is partitioned only along the batch dimensions, propagate - // the sharding to the output, since batch dimensions are the easiest to - // partition. - if (IsSpatiallyPartitioned(lhs) && - partitioned_only_along_non_trivial_dims(lhs->sharding(), - dot_dims->batch_dims, 0)) { - return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_lhs(), - instruction); - } - if (IsSpatiallyPartitioned(rhs) && - partitioned_only_along_non_trivial_dims(rhs->sharding(), - dot_dims->batch_dims, 1)) { - return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_rhs(), - instruction); - } - if (aggressive_prop) { - // If LHS/RHS is partitioned only along the non-contracting - // dimensions, propagate the sharding to the output. - const bool can_propagate_from_lhs = - IsSpatiallyPartitioned(lhs) && - partitioned_only_along_non_trivial_dims( - lhs->sharding(), dot_dims->lhs_non_contracting_dims, 0); - const bool can_propagate_from_rhs = - IsSpatiallyPartitioned(rhs) && - partitioned_only_along_non_trivial_dims( - rhs->sharding(), dot_dims->rhs_non_contracting_dims, 1); - // If we can propagate from both operands, choose the larger one which - // should help us reduce communications. - if (can_propagate_from_lhs && can_propagate_from_rhs) { - if (Product(lhs->shape().dimensions()) >= - Product(rhs->shape().dimensions())) { - return MaybeImproveInstructionSharding( - get_tiled_sharding_based_on_lhs(), instruction); - } else { - return MaybeImproveInstructionSharding( - get_tiled_sharding_based_on_rhs(), instruction); - } - } - if (can_propagate_from_lhs) { - return MaybeImproveInstructionSharding( - get_tiled_sharding_based_on_lhs(), instruction); - } - if (can_propagate_from_rhs) { - return MaybeImproveInstructionSharding( - get_tiled_sharding_based_on_rhs(), instruction); - } - } - } - if (!IsSpatiallyPartitioned(lhs)) { return false; } if (lhs->sharding().IsReplicated()) { - return MaybeImproveInstructionSharding(HloSharding::Replicate(), - instruction); + return MaybeImproveInstructionSharding( + HloSharding::Replicate(), instruction, may_combine_partial_sharding); } if (IsConvolutionKernelSmall(instruction)) { @@ -488,11 +634,28 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction, return false; } return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_lhs(), - instruction); + instruction, + may_combine_partial_sharding); } // If the kernel is large (e.g backward convolution) then we only support // replicated output. - return MaybeImproveInstructionSharding(HloSharding::Replicate(), instruction); + return MaybeImproveInstructionSharding(HloSharding::Replicate(), instruction, + may_combine_partial_sharding); +} + +bool CanPropagateThroughAtAgressiveLevel(const HloInstruction& inst, + int64 aggressiveness) { + // At minimum agressiveness, only allow pass-through ops. 
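In effect, `aggressiveness` is a threshold: at 0, sharding only flows through elementwise ops, transposes, and reshapes; under SPMD, values above 0 additionally enable `may_combine_partial_sharding`; and forward propagation through broadcast (the `kBroadcast` case further down) waits until level 3, to avoid resharding immediately after a broadcast.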
+ if (aggressiveness < 1 && !inst.IsElementwise() && + inst.opcode() != HloOpcode::kTranspose && + inst.opcode() != HloOpcode::kReshape) { + return false; + } + // Broadcast propagation should have at least aggressiveness 2. + if (aggressiveness < 2 && inst.opcode() == HloOpcode::kBroadcast) { + return false; + } + return true; } // Tries to update the sharding of the specified instruction based on its @@ -500,7 +663,11 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction, // changed and false otherwise. bool InferShardingFromOperands(HloInstruction* instruction, const ComputationMap& computation_map, - bool is_spmd, bool aggressive_prop) { + bool is_spmd, int64 aggressiveness) { + if (!CanPropagateThroughAtAgressiveLevel(*instruction, aggressiveness)) { + return false; + } + const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0; if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) { // If an array shaped HLO doesn't support spatial partitioning but at least // one of its operand is replicated then we make the HLO replicated as well. @@ -512,8 +679,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (absl::c_any_of(instruction->operands(), [](const HloInstruction* op) { return op->has_sharding() && op->sharding().IsReplicated(); })) { - return MaybeImproveInstructionSharding(HloSharding::Replicate(), - instruction); + return MaybeImproveInstructionSharding( + HloSharding::Replicate(), instruction, may_combine_partial_sharding); } return false; } @@ -526,7 +693,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, } HloSharding new_sharding = operand->sharding().GetSubSharding( operand->shape(), {instruction->tuple_index()}); - return MaybeImproveInstructionSharding(new_sharding, instruction); + return MaybeImproveInstructionSharding( + std::move(new_sharding), instruction, may_combine_partial_sharding); } case HloOpcode::kTuple: { if (absl::c_none_of(instruction->operands(), @@ -591,60 +759,60 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (!IsSpatiallyPartitioned(operand)) { continue; } - auto get_maybe_tuple_sharding = [&](const HloSharding& sharding) { + auto get_maybe_tuple_sharding = [&](HloSharding sharding) { if (instruction->operand_count() == 2) { return sharding; } std::vector tuple(instruction->operand_count() / 2, - sharding); + std::move(sharding)); return HloSharding::Tuple(instruction->shape(), tuple); }; - if (operand->sharding().IsReplicated()) { + if (operand->sharding().IsReplicated() || + (!is_spmd && + absl::c_any_of(instruction->dimensions(), [operand](int64 dim) { + return operand->sharding().tile_assignment().dim(dim) > 1; + }))) { + // We are reducing along one of the sharded dimensions. We only + // support this in SPMD. changed |= MaybeImproveInstructionSharding( - get_maybe_tuple_sharding(HloSharding::Replicate()), instruction); + get_maybe_tuple_sharding(HloSharding::Replicate()), instruction, + may_combine_partial_sharding); continue; } - if (absl::c_any_of(instruction->dimensions(), [operand](int64 dim) { - return operand->sharding().tile_assignment().dim(dim) > 1; - })) { - // We are reducing along one of the sharded dimensions. We don't - // support tiled sharding in this case. + auto after_partial_replication = + operand->sharding().IsReplicated() + ? 
operand->sharding() + : hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + operand->sharding(), instruction->dimensions()); + if (after_partial_replication.IsReplicated()) { changed |= MaybeImproveInstructionSharding( - get_maybe_tuple_sharding(HloSharding::Replicate()), instruction); - } else { - // We are reducing along some of the non-sharded dimensions. The - // result sharding should be the same as the operand sharding with the - // reduction dimensions removed as they are removed from the result - // shape. - std::vector target_tile_assignment_dimensions; - const auto& dimensions = instruction->dimensions(); - for (int64 i = 0; i < operand->shape().rank(); ++i) { - if (absl::c_find(dimensions, i) == dimensions.end()) { - target_tile_assignment_dimensions.push_back( - operand->sharding().tile_assignment().dim(i)); - } - } - Array new_tile_assignment = - operand->sharding().tile_assignment(); - new_tile_assignment.Reshape(target_tile_assignment_dimensions); - // Use the same sharding for all tuple elements, because they are part - // of the same reduce instruction. - HloSharding new_sharding = - get_maybe_tuple_sharding(HloSharding::Tile(new_tile_assignment)); - changed |= MaybeImproveInstructionSharding(new_sharding, instruction); + get_maybe_tuple_sharding(HloSharding::Replicate()), instruction, + may_combine_partial_sharding); + continue; } + // Use the same sharding for all tuple elements, because they are part + // of the same reduce instruction. + HloSharding new_sharding = + get_maybe_tuple_sharding(hlo_sharding_util::RemoveShapeDimensions( + after_partial_replication, instruction->dimensions())); + changed |= MaybeImproveInstructionSharding( + std::move(new_sharding), instruction, may_combine_partial_sharding); } return changed; } case HloOpcode::kBroadcast: { - const HloInstruction* op = instruction->operand(0); - if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) { + // Make forward propagation through broadcast low priority to avoid + // resharding after broadcast. + if (aggressiveness < 3) { return false; } - // Heuristic: If an operand is more than 8 times fewer elements than its - // output, do not propagate sharding. - if (ShapeUtil::ElementsIn(instruction->shape()) > - 8 * ShapeUtil::ElementsIn(op->shape())) { + // Do not override existing tile sharding. This is likely from users. + if (IsSpatiallyPartitioned(instruction) && + !instruction->sharding().IsTileMaximal()) { + return false; + } + const HloInstruction* op = instruction->operand(0); + if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) { return false; } // The output will be tiled along the broadcasted dimension the same way @@ -662,13 +830,22 @@ bool InferShardingFromOperands(HloInstruction* instruction, op->sharding().tile_assignment().dim(source_dim)); } } + if (op->sharding().ReplicateOnLastTileDim()) { + target_tile_assignment_dimensions.push_back( + op->sharding().tile_assignment().dimensions().back()); + } Array new_tile_assignment = op->sharding().tile_assignment(); new_tile_assignment.Reshape(target_tile_assignment_dimensions); - HloSharding new_sharding = HloSharding::Tile(new_tile_assignment); - return MaybeImproveInstructionSharding(new_sharding, instruction); + HloSharding new_sharding = + op->sharding().ReplicateOnLastTileDim() + ? 
HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); + return MaybeImproveInstructionSharding( + std::move(new_sharding), instruction, may_combine_partial_sharding); } case HloOpcode::kConvolution: - return InferConvolutionShardingFromOperands(instruction, aggressive_prop); + return InferConvolutionShardingFromOperands(instruction, aggressiveness, + may_combine_partial_sharding); case HloOpcode::kTranspose: { const HloInstruction* input = instruction->operand(0); if (!IsSpatiallyPartitioned(input)) { @@ -676,7 +853,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, } HloSharding sharding = hlo_sharding_util::TransposeSharding( input->sharding(), instruction->dimensions()); - return MaybeImproveInstructionSharding(sharding, instruction); + return MaybeImproveInstructionSharding(std::move(sharding), instruction, + may_combine_partial_sharding); } case HloOpcode::kReduceWindow: { const HloInstruction* lhs = instruction->operand(0); @@ -694,7 +872,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, << instruction->ToString(); return false; } - return MaybeImproveInstructionSharding(lhs->sharding(), instruction); + return MaybeImproveInstructionSharding(lhs->sharding(), instruction, + may_combine_partial_sharding); } case HloOpcode::kSelectAndScatter: { // Shard according to first operand, as output keeps the same shape. @@ -713,7 +892,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, << instruction->ToString(); return false; } - return MaybeImproveInstructionSharding(lhs->sharding(), instruction); + return MaybeImproveInstructionSharding(lhs->sharding(), instruction, + may_combine_partial_sharding); } case HloOpcode::kReshape: { if (!IsSpatiallyPartitioned(instruction->operand(0))) { @@ -724,8 +904,9 @@ bool InferShardingFromOperands(HloInstruction* instruction, instruction->operand(0)->shape(), instruction->shape(), instruction->operand(0)->sharding()); if (new_sharding.has_value()) { - return MaybeImproveInstructionSharding(new_sharding.value(), - instruction); + return MaybeImproveInstructionSharding(std::move(*new_sharding), + instruction, + may_combine_partial_sharding); } return false; } @@ -736,83 +917,13 @@ bool InferShardingFromOperands(HloInstruction* instruction, return MaybeImproveInstructionSharding( hlo_sharding_util::ReverseSharding( instruction->operand(0)->sharding(), instruction->dimensions()), - instruction); + instruction, may_combine_partial_sharding); } case HloOpcode::kDot: { - auto& dot_dim_numbs = instruction->dot_dimension_numbers(); - // Batch dimensions are the same for lhs and rhs on dot operations. - int64 num_batch_dims = dot_dim_numbs.lhs_batch_dimensions_size(); - std::vector contracting_dims(2); - contracting_dims[0] = dot_dim_numbs.lhs_contracting_dimensions(0); - contracting_dims[1] = dot_dim_numbs.rhs_contracting_dimensions(0); - std::vector ops_sharding(2, nullptr); - for (int64 op_num = 0; op_num < 2; ++op_num) { - const HloInstruction* op = instruction->operand(op_num); - if (IsSpatiallyPartitioned(op)) { - ops_sharding[op_num] = &op->sharding(); - } - } - if (ops_sharding[0] == nullptr && ops_sharding[1] == nullptr) { - return false; - } - - // Select representative operand. 
- int64 representative_op = -1; - if (ops_sharding[0] == nullptr) { - representative_op = 1; - } else if (ops_sharding[1] == nullptr) { - representative_op = 0; - } else if (ops_sharding[0]->IsReplicated() && - ops_sharding[1]->IsReplicated()) { - // Both replicated -> replicate - return MaybeImproveInstructionSharding(HloSharding::Replicate(), - instruction); - } else if (!ops_sharding[0]->IsReplicated() && - !ops_sharding[1]->IsReplicated()) { - // Both tile sharded. The dot spatial partitioning implementation - // replicates the operand corresponding to the non-tiled dimension: - // dot(lhs, rhs), sharding={devices=[1, ..., n, 1]} replicates rhs - // dot(lhs, rhs), sharding={devices=[1, ..., 1, n]} replicates lhs - // so set sharding in order to replicate the smaller of lhs and rhs - representative_op = - ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) < - ShapeUtil::ByteSizeOf(instruction->operand(1)->shape()) - ? 1 - : 0; - } else { - // One is replicated and the other is tiled - pick the tiled one. - representative_op = ops_sharding[0]->IsReplicated() ? 1 : 0; - } - - if (ops_sharding[representative_op]->IsReplicated()) { - return MaybeImproveInstructionSharding(HloSharding::Replicate(), - instruction); - } else { - // Tile-shard instruction according to representative op. - auto sharding = *ops_sharding[representative_op]; - if (instruction->shape().dimensions_size() != - sharding.tile_assignment().num_dimensions()) { - // It is necessarily the case of a matrix x vector, with - // representative_op being the matrix, because the vector op has the - // same shape as instruction. - CHECK_EQ(sharding.tile_assignment().num_dimensions(), - instruction->shape().dimensions_size() + 1); - // Reshape sharding so that last dimension is 1, and then remove - // last dimension. 
- std::vector non_batch_dims( - sharding.tile_assignment().num_dimensions() - num_batch_dims); - absl::c_iota(non_batch_dims, num_batch_dims); - sharding = hlo_sharding_util::ReshapeToTileDimension( - sharding, num_batch_dims, non_batch_dims); - auto tile_assignment = sharding.tile_assignment(); - auto dimensions = tile_assignment.dimensions(); - CHECK_EQ(dimensions.back(), 1); - dimensions.pop_back(); - tile_assignment.Reshape(dimensions); - sharding = HloSharding::Tile(tile_assignment); - } - return MaybeImproveInstructionSharding(sharding, instruction); - } + const auto& dnums = + dot_as_convolution_util::ParseDotGeneralFromDot(instruction); + return InferDotShardingFromOperands(instruction, dnums, + may_combine_partial_sharding); } case HloOpcode::kParameter: { auto parent_it = computation_map.find(instruction->parent()); @@ -826,7 +937,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (parent->called_computations()[i - 1] == instruction->parent()) { if (parent->operand(i)->has_sharding()) { return MaybeImproveInstructionSharding( - parent->operand(i)->sharding(), instruction); + parent->operand(i)->sharding(), instruction, + may_combine_partial_sharding); } return false; } @@ -853,15 +965,15 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (instruction->shape().IsTuple()) { return MaybeImproveInstructionSharding( HloSharding::SingleTuple(instruction->shape(), operand->sharding()), - instruction); + instruction, may_combine_partial_sharding); } else { - return MaybeImproveInstructionSharding(operand->sharding(), - instruction); + return MaybeImproveInstructionSharding(operand->sharding(), instruction, + may_combine_partial_sharding); } } case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: { - auto propagate_slicing = [instruction]() { + auto propagate_slicing = [&]() { const HloInstruction* operand = instruction->opcode() == HloOpcode::kDynamicSlice ? 
instruction->operand(0) @@ -872,7 +984,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (operand->sharding().IsReplicated()) { return MaybeImproveInstructionSharding(HloSharding::Replicate(), - instruction); + instruction, + may_combine_partial_sharding); } const auto& tile_assignment = operand->sharding().tile_assignment(); @@ -883,10 +996,10 @@ bool InferShardingFromOperands(HloInstruction* instruction, return false; } } - return MaybeImproveInstructionSharding(operand->sharding(), - instruction); + return MaybeImproveInstructionSharding(operand->sharding(), instruction, + may_combine_partial_sharding); }; - auto propagate_base = [instruction]() { + auto propagate_base = [&]() { if (instruction->opcode() != HloOpcode::kDynamicUpdateSlice) { return false; } @@ -894,7 +1007,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, return false; } return MaybeImproveInstructionSharding( - instruction->operand(0)->sharding(), instruction); + instruction->operand(0)->sharding(), instruction, + may_combine_partial_sharding); }; return propagate_slicing() || propagate_base(); } @@ -903,15 +1017,17 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (IsSpatiallyPartitioned(instruction->operand(1))) { HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding( instruction->operand(1)->sharding(), instruction); - changed |= MaybeImproveInstructionSharding(new_sharding, instruction); + changed |= MaybeImproveInstructionSharding( + std::move(new_sharding), instruction, may_combine_partial_sharding); } if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) { auto maybe_from_data = hlo_sharding_util::GatherOutputShardingFromDataOperand( instruction->operand(0)->sharding(), *instruction); if (maybe_from_data) { - changed |= - MaybeImproveInstructionSharding(*maybe_from_data, instruction); + changed |= MaybeImproveInstructionSharding( + std::move(*maybe_from_data), instruction, + may_combine_partial_sharding); } } return changed; @@ -920,7 +1036,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, bool changed = false; if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) { changed |= MaybeImproveInstructionSharding( - instruction->operand(0)->sharding(), instruction); + instruction->operand(0)->sharding(), instruction, + may_combine_partial_sharding); } if (!IsSpatiallyPartitioned(instruction->operand(1)) && !IsSpatiallyPartitioned(instruction->operand(2))) { @@ -931,12 +1048,13 @@ bool InferShardingFromOperands(HloInstruction* instruction, hlo_sharding_util::ScatterOutputShardingFromUpdate( instruction->operand(2)->sharding(), *instruction); if (maybe_from_update) { - changed |= - MaybeImproveInstructionSharding(*maybe_from_update, instruction); + changed |= MaybeImproveInstructionSharding( + std::move(*maybe_from_update), instruction, + may_combine_partial_sharding); } } - changed |= MaybeImproveInstructionSharding(HloSharding::Replicate(), - instruction); + changed |= MaybeImproveInstructionSharding( + HloSharding::Replicate(), instruction, may_combine_partial_sharding); return changed; } case HloOpcode::kWhile: { @@ -948,50 +1066,143 @@ bool InferShardingFromOperands(HloInstruction* instruction, sharding = MergeForMoreSpecificSharding(sharding, instruction->sharding()); } - return MaybeImproveInstructionSharding(sharding, instruction); + return MaybeImproveInstructionSharding(std::move(sharding), instruction, + may_combine_partial_sharding); } default: { + if (instruction->IsElementwise() && may_combine_partial_sharding) 
{ + bool changed = false; + for (auto operand : instruction->operands()) { + if (IsSpatiallyPartitioned(operand)) { + changed |= MaybeImproveInstructionSharding( + operand->sharding(), instruction, may_combine_partial_sharding); + } + } + return changed; + } const HloInstruction* operand = PickRepresentativeOperand(instruction); if (!operand || !IsSpatiallyPartitioned(operand)) { return false; } - return MaybeImproveInstructionSharding(operand->sharding(), instruction); + return MaybeImproveInstructionSharding(operand->sharding(), instruction, + may_combine_partial_sharding); } } return false; } +HloSharding InferDotOperandSharding( + const HloInstruction* instruction, + const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, + int64 operand_index, bool may_combine_partial_sharding) { + auto operand = instruction->operand(operand_index); + auto other = instruction->operand(1 - operand_index); + std::vector output_dims_to_replicate; + std::vector other_operand_dims_to_replicate; + for (const auto& dim : operand_index == 0 ? dnums.rhs_non_contracting_dims + : dnums.lhs_non_contracting_dims) { + output_dims_to_replicate.push_back(dim.output); + other_operand_dims_to_replicate.push_back(operand_index == 0 ? dim.rhs + : dim.lhs); + } + // If this dot is interpreted from a conv, then contracting dims may have + // corresponding spatial dimensions in the output, and this operand's + // non-contracting dims may have corresponding spatial dims in the other + // operand. + for (const auto& dim : dnums.contracting_dims) { + if (dim.output >= 0) { + output_dims_to_replicate.push_back(dim.output); + } + } + for (const auto& dim : operand_index == 0 ? dnums.lhs_non_contracting_dims + : dnums.rhs_non_contracting_dims) { + int64 other_dim = operand_index == 0 ? dim.rhs : dim.lhs; + if (other_dim >= 0) { + other_operand_dims_to_replicate.push_back(other_dim); + } + } + auto output_other_dims_replicated = + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + instruction->sharding(), output_dims_to_replicate); + std::vector output_to_operand_dims(instruction->shape().rank(), -1); + std::vector operand_to_output_dims(operand->shape().rank(), -1); + for (const auto& dim : dnums.batch_dims) { + output_to_operand_dims[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs; + operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = dim.output; + } + for (const auto& dim : operand_index == 0 ? dnums.lhs_non_contracting_dims + : dnums.rhs_non_contracting_dims) { + output_to_operand_dims[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs; + operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = dim.output; + } + auto sharding = *hlo_sharding_util::TransposeShardingWithCollapsedDims( + output_other_dims_replicated, output_to_operand_dims, + operand_to_output_dims); + if (IsSpatiallyPartitioned(other)) { + auto other_operand_dims_replicated = + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + other->sharding(), other_operand_dims_to_replicate); + std::vector other_to_operand_dims(other->shape().rank(), -1); + std::vector operand_to_other_dims(operand->shape().rank(), -1); + for (const auto& dim : dnums.batch_dims) { + other_to_operand_dims[operand_index == 0 ? dim.rhs : dim.lhs] = + operand_index == 0 ? dim.lhs : dim.rhs; + operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] = + operand_index == 0 ? dim.rhs : dim.lhs; + } + for (const auto& dim : dnums.contracting_dims) { + other_to_operand_dims[operand_index == 0 ? dim.rhs : dim.lhs] = + operand_index == 0 ? 
dim.lhs : dim.rhs; + operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] = + operand_index == 0 ? dim.rhs : dim.lhs; + } + HloSharding sharding_from_other = + *hlo_sharding_util::TransposeShardingWithCollapsedDims( + other_operand_dims_replicated, other_to_operand_dims, + operand_to_other_dims); + if (MergeSharding(sharding, &sharding_from_other, + may_combine_partial_sharding)) { + sharding = std::move(sharding_from_other); + } + } + return sharding; +} + // Return the sharding that should be propagated from user to instruction. absl::optional GetShardingFromUser( const HloInstruction& instruction, const HloInstruction& user, - bool aggressive_prop, bool is_spmd) { + int64 aggressiveness, bool is_spmd) { + if (!CanPropagateThroughAtAgressiveLevel(user, aggressiveness)) { + return absl::nullopt; + } if (!IsSpatiallyPartitioned(&user)) { return absl::nullopt; } + const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0; switch (user.opcode()) { case HloOpcode::kBroadcast: { if (user.sharding().IsReplicated()) { return user.sharding(); } - // Only support when none of the partitioned dimensions in the broadcast - // output belong to new dimensions. + std::vector dims_to_replicate; + bool needs_replication = false; for (int64 i = 0; i < user.shape().rank(); ++i) { - if (user.sharding().tile_assignment().dim(i) > 1 && - absl::c_count(user.dimensions(), i) == 0) { - return absl::nullopt; + if (absl::c_count(user.dimensions(), i) == 0) { + dims_to_replicate.push_back(i); + if (user.sharding().tile_assignment().dim(i) > 1) { + needs_replication = true; + } } } - - // The instruction (operand of broadcast) will be tiled the same way - // as the output. - std::vector target_tile_assignment_dimensions; - for (int64 output_dim : user.dimensions()) { - target_tile_assignment_dimensions.push_back( - user.sharding().tile_assignment().dim(output_dim)); + // If not SPMD, only support when none of the partitioned dimensions in + // the broadcast output belong to new dimensions. + if (!is_spmd && needs_replication) { + return absl::nullopt; } - Array new_tile_assignment = user.sharding().tile_assignment(); - new_tile_assignment.Reshape(target_tile_assignment_dimensions); - return HloSharding::Tile(new_tile_assignment); + return hlo_sharding_util::RemoveShapeDimensions( + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + user.sharding(), dims_to_replicate), + dims_to_replicate); } case HloOpcode::kConcatenate: { if (user.sharding().IsReplicated()) { @@ -1036,64 +1247,11 @@ absl::optional GetShardingFromUser( return HloSharding::Tile(new_tile_assignment); } case HloOpcode::kConvolution: { - if (auto dot_dims = - dot_as_convolution_util::ParseDotGeneralFromConvolution(&user)) { - const auto& dnums = user.convolution_dimension_numbers(); - auto partitioned_only_along_non_trivial_dims = - [&](const HloSharding& sharding, - std::vector& - dims) { - if (sharding.IsTileMaximal()) { - return false; - } - int64 partition_count = 1; - for (const auto& dim : dims) { - if (user.shape().dimensions(dim.output) == 1) { - continue; - } - partition_count *= sharding.tile_assignment().dim(dim.output); - } - return partition_count == - sharding.tile_assignment().num_elements(); - }; - // If output is partitioned only along the batch dimensions, or only - // along the non-contracting dimensions, propagate the sharding to the - // operand. 
- if (&instruction == user.operand(0) && - (partitioned_only_along_non_trivial_dims(user.sharding(), - dot_dims->batch_dims) || - partitioned_only_along_non_trivial_dims( - user.sharding(), dot_dims->lhs_non_contracting_dims))) { - std::vector lhs_to_output_indices(user.shape().rank()); - lhs_to_output_indices[dnums.input_batch_dimension()] = - dnums.output_batch_dimension(); - lhs_to_output_indices[dnums.input_feature_dimension()] = - dnums.output_feature_dimension(); - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - lhs_to_output_indices[dnums.input_spatial_dimensions(i)] = - dnums.output_spatial_dimensions(i); - } - return hlo_sharding_util::TransposeSharding(user.sharding(), - lhs_to_output_indices); - } - if (&instruction == user.operand(1) && - (partitioned_only_along_non_trivial_dims(user.sharding(), - dot_dims->batch_dims) || - partitioned_only_along_non_trivial_dims( - user.sharding(), dot_dims->rhs_non_contracting_dims))) { - std::vector rhs_to_output_indices(user.shape().rank()); - rhs_to_output_indices[dnums.kernel_input_feature_dimension()] = - dnums.output_batch_dimension(); - rhs_to_output_indices[dnums.kernel_output_feature_dimension()] = - dnums.output_feature_dimension(); - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - rhs_to_output_indices[dnums.kernel_spatial_dimensions(i)] = - dnums.output_spatial_dimensions(i); - } - return hlo_sharding_util::TransposeSharding(user.sharding(), - rhs_to_output_indices); - } + auto dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(&user); + if (dot_dims.conv_spatial_dims.empty()) { + int64 op_idx = user.operand_index(&instruction); + return InferDotOperandSharding(&user, dot_dims, op_idx, + may_combine_partial_sharding); } return absl::nullopt; } @@ -1175,33 +1333,10 @@ absl::optional GetShardingFromUser( return new_sharding; } case HloOpcode::kDot: { - if (user.sharding().IsReplicated()) { - return user.sharding(); - } - auto& dim_numbers = user.dot_dimension_numbers(); int64 op_idx = user.operand_index(&instruction); - // Batch dimensions are the same on lhs and rhs for dot operations. - int64 num_batch_dims = dim_numbers.lhs_batch_dimensions_size(); - int64 num_spatial_dims = - instruction.shape().dimensions_size() - num_batch_dims; - if (num_spatial_dims == 1) { - // This is the vector of a matrix x vector operation -> replicate, - // since tiling on the vector would necessarily be on the contracting - // dimension, which we don't support. - CHECK_EQ(op_idx, 1); - return HloSharding::Replicate(); - } - // Instruction is necessarily a matrix because it is one of the operands - // of a matrix x matrix operation. - CHECK_EQ(num_spatial_dims, 2); - // Propagate tile sharding to the bigger operand, and replicate the other. - auto other_op = user.operand(op_idx ^ 1); - if (ShapeUtil::ByteSizeOf(instruction.shape()) > - ShapeUtil::ByteSizeOf(other_op->shape())) { - return user.sharding(); - } else { - return HloSharding::Replicate(); - } + auto dnums = dot_as_convolution_util::ParseDotGeneralFromDot(&user); + return InferDotOperandSharding(&user, dnums, op_idx, + may_combine_partial_sharding); } case HloOpcode::kReduce: { if (instruction.shape().rank() == 0) { @@ -1216,10 +1351,11 @@ absl::optional GetShardingFromUser( return user_sharding; } std::vector target_tile_assignment_dimensions( - instruction.shape().rank()); + instruction.shape().rank() + + (user_sharding.ReplicateOnLastTileDim() ? 
1 : 0)); const auto& dimensions = user.dimensions(); int64 next_output_dim = 0; - for (int64 i = 0; i < instruction.shape().rank(); ++i) { + for (int64 i = 0; i < target_tile_assignment_dimensions.size(); ++i) { if (absl::c_find(dimensions, i) == dimensions.end()) { target_tile_assignment_dimensions[i] = user_sharding.tile_assignment().dim(next_output_dim++); @@ -1229,7 +1365,9 @@ absl::optional GetShardingFromUser( } auto tile_assignment = user_sharding.tile_assignment(); tile_assignment.Reshape(target_tile_assignment_dimensions); - return HloSharding::Tile(tile_assignment); + return user_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(tile_assignment) + : HloSharding::Tile(tile_assignment); } case HloOpcode::kSort: { if (user.sharding().IsTuple()) { @@ -1299,17 +1437,21 @@ absl::optional GetShardingFromUser( // false otherwise. bool InferShardingFromUsers(HloInstruction* instruction, const ComputationMap& computation_map, - bool aggressive_prop, bool is_spmd) { + int64 aggressiveness, bool is_spmd) { + if (aggressiveness < 2 && instruction->opcode() == HloOpcode::kBroadcast) { + return false; + } if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) { return false; } bool improved_sharding = false; + const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0; for (const HloInstruction* user : instruction->users()) { absl::optional user_sharding = - GetShardingFromUser(*instruction, *user, aggressive_prop, is_spmd); + GetShardingFromUser(*instruction, *user, aggressiveness, is_spmd); if (user_sharding) { - improved_sharding |= - MaybeImproveInstructionSharding(*user_sharding, instruction); + improved_sharding |= MaybeImproveInstructionSharding( + std::move(*user_sharding), instruction, may_combine_partial_sharding); } } return improved_sharding; @@ -1579,10 +1721,12 @@ StatusOr ShardingPropagation::Run(HloModule* module) { // strictly improve the sharding of the graph and it can't be improved // indefinitely. int64 iterations = 0; - auto run_to_fix_point = [&](bool aggressive_prop) { - bool changed = true; - while (changed) { - changed = false; + auto run_to_fix_point = [&](int64 aggressiveness) { + absl::flat_hash_set already_inferred_from_operands; + absl::flat_hash_set already_inferred_from_users; + bool changed_last_iter = true; + while (changed_last_iter) { + changed_last_iter = false; int64 inferred_from_operand_counter = 0; int64 inferred_from_user_counter = 0; int64 instruction_counter = 0; @@ -1595,42 +1739,55 @@ StatusOr ShardingPropagation::Run(HloModule* module) { for (const HloInstruction* instruction : instructions) { already_sharded_counter += (instruction->has_sharding() ? 1 : 0); } - - // Remove the instructions where the sharding was provided from the - // outside so we don't modify them. - instructions.erase( - std::remove_if(instructions.begin(), instructions.end(), - [&](HloInstruction* instruction) { - return provided_shardings.contains(instruction); - }), - instructions.end()); - // First iterate the HLO graph in post order taking shardings from // operands. 
for (HloInstruction* instruction : instructions) { + if (already_inferred_from_operands.contains(instruction) || + provided_shardings.contains(instruction)) { + continue; + } + already_inferred_from_operands.insert(instruction); if (InferShardingFromOperands(instruction, computation_map, is_spmd_, - aggressive_prop)) { + aggressiveness)) { ++inferred_from_operand_counter; - changed = true; + any_changed = true; VLOG(2) << "Add sharding (forward-pass): " << instruction->ToString(); maybe_computation_propagation(instruction); + for (auto operand : instruction->operands()) { + already_inferred_from_users.erase(operand); + } + for (auto user : instruction->users()) { + already_inferred_from_operands.erase(user); + } + changed_last_iter = true; } } // Then iterate the HLO graph in reverse post order taking shardings // from users. for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) { - if (InferShardingFromUsers(*it, computation_map, aggressive_prop, + if (already_inferred_from_users.contains(*it) || + provided_shardings.contains(*it)) { + continue; + } + already_inferred_from_users.insert(*it); + if (InferShardingFromUsers(*it, computation_map, aggressiveness, is_spmd_)) { ++inferred_from_user_counter; - changed = true; + any_changed = true; VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString(); maybe_computation_propagation(*it); + for (auto operand : (*it)->operands()) { + already_inferred_from_users.erase(operand); + } + for (auto user : (*it)->users()) { + already_inferred_from_operands.erase(user); + } + changed_last_iter = true; } } } - any_changed |= changed; VLOG(1) << "Sharding propagation iteration " << iterations << ";"; VLOG(1) << " total instructions: " << instruction_counter; VLOG(1) << " instructions already sharded: " << already_sharded_counter; @@ -1638,11 +1795,13 @@ StatusOr ShardingPropagation::Run(HloModule* module) { << inferred_from_operand_counter; VLOG(1) << " shardings inferred from users: " << inferred_from_user_counter; + VLOG(1) << " aggressiveness: " << aggressiveness; ++iterations; } }; - run_to_fix_point(false); - run_to_fix_point(true); + for (int64 aggressiveness = 0; aggressiveness < 4; ++aggressiveness) { + run_to_fix_point(aggressiveness); + } VLOG(1) << "Sharding propagation completed after " << iterations << " iterations"; diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc index 594130daf0b..8c4d8fc24ff 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc @@ -65,22 +65,6 @@ ENTRY %elementwise { op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); } -TEST_F(ShardingPropagationTest, BroadcastForwardPassNoSharding) { - const char* const hlo_string = R"( -HloModule module -ENTRY %broadcast { - %param0 = f32[7,11]{1,0} parameter(0), - sharding={devices=[2,2]0,1,2,3} - %broadcast = f32[5,7,11,13]{3,2,1,0} broadcast(%param0), dimensions={1,2} - ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%broadcast) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, - ShardingPropagation().Run(module.get())); - EXPECT_FALSE(changed); -} - // Regression Test for b/129569657. 
TEST_F(ShardingPropagationTest, BroadcastForwardPass) { const char* const hlo_string = R"( @@ -118,6 +102,25 @@ ENTRY %broadcast { op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); } +TEST_F(ShardingPropagationTest, BroadcastForwardPartial) { + const char* const hlo_string = R"( +HloModule module +ENTRY %broadcast { + %param0 = f32[3,2048]parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %broadcast = f32[3,2048,3] broadcast(%param0), dimensions={0,1} + ROOT %copy = f32[3,2048,3] copy(%broadcast) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "broadcast"), + op::Sharding("{devices=[1,2,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, BroadcastUser) { const char* const hlo_string = R"( HloModule module @@ -136,6 +139,25 @@ ENTRY %broadcast { op::Sharding("{devices=[2,4]0,1,2,3,4,5,6,7}")); } +TEST_F(ShardingPropagationTest, BroadcastUserPartial) { + const char* const hlo_string = R"( +HloModule module +ENTRY %broadcast { + %param0 = f32[24,8]{0,1} parameter(0) + %copy = f32[24,8]{0,1} copy(%param0) + ROOT %broadcast = f32[4,24,6,8] broadcast(%copy), dimensions={1,3}, + sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "copy"), + op::Sharding("{devices=[2,1,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, MaximalReduceForwardPass) { const char* const hlo_string = R"( HloModule module @@ -184,6 +206,78 @@ ENTRY %reduce { op::Sharding("{devices=[2,2]0,1,2,3}")); } +TEST_F(ShardingPropagationTest, ReducePartiallyOnTiledDims) { + const char* const hlo_string = R"( +HloModule module +%add { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(%lhs, %rhs) +} +ENTRY %reduce { + %param0 = f32[8,8] parameter(0), sharding={devices=[2,2]0,1,2,3} + %init = f32[] parameter(1) + %reduce = f32[8] reduce(%param0, %init), dimensions={0}, to_apply=%add + ROOT %copy = f32[8] copy(%reduce) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "reduce"), + op::Sharding("{devices=[2,2]0,2,1,3 last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, ReducePartiallyOnTiledDims2) { + const char* const hlo_string = R"( +HloModule module +%add { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(%lhs, %rhs) +} +ENTRY %reduce { + %param0 = f32[8,8] parameter(0), sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %init = f32[] parameter(1) + %reduce = f32[8] reduce(%param0, %init), dimensions={0}, to_apply=%add + ROOT %copy = f32[8] copy(%reduce) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "reduce"), + op::Sharding("{devices=[2,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); +} + 
+TEST_F(ShardingPropagationTest, ReducePartiallyBackward) { + const char* const hlo_string = R"( +HloModule module +%add { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(%lhs, %rhs) +} +ENTRY %reduce { + %param0 = f32[8,8] parameter(0) + %input = f32[8,8] copy(%param0) + %init = f32[] parameter(1) + %reduce = f32[8] reduce(%input, %init), dimensions={0}, to_apply=%add, + sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate} + ROOT %copy = f32[8] copy(%reduce) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "input"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ShardedTupleReduceForwardAndBackwardPass) { const char* const hlo_string = R"( HloModule module @@ -420,6 +514,26 @@ ENTRY %pad { op::Sharding("{devices=[2,2]0,1,2,3}")); } +TEST_F(ShardingPropagationTest, PartialReplicatedPadForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %pad { + %input = f32[11,17]{1,0} parameter(0), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %pad_value = f32[] parameter(1) + %pad = f32[27,51]{1,0} pad(%input, %pad_value), padding=2_4_1x1_1_2 + ROOT %copy = f32[27,51]{1,0} copy(%pad) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "pad"), + op::Sharding("{devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ShardedPreferredOverReplicated) { const char* const hlo_string = R"( HloModule module @@ -446,6 +560,43 @@ ENTRY %replicated { op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); } +TEST_F(ShardingPropagationTest, PartialReplicateReshapeForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %reshape { + %param0 = f32[1430,1]{1,0} parameter(0), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + %reshape = f32[10,11,13]{2,1,0} reshape(%param0) + ROOT %copy = f32[10,11,13]{2,1,0} copy(%reshape) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "reshape"), + op::Sharding("{devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, PartialReplicateReshapeBackwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %reshape { + %param0 = f32[2002,1]{1,0} parameter(0) + %copy = f32[2002,1]{1,0} copy(f32[2002,1]{1,0} %param0) + ROOT %reshape = f32[14,11,13]{2,1,0} reshape(%copy), + sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "copy"), + op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, DontShardTuplesIfAllInputIsMaximal) { const char* const hlo_string = R"( HloModule module @@ -506,6 +657,25 @@ ENTRY %slice { 
op::Sharding("{devices=[2,1]0,1}")); } +TEST_F(ShardingPropagationTest, PartialReplicatedStridedSlice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %slice { + %param = f32[17,13]{1,0} parameter(0), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + %slice = f32[7,5]{1,0} slice(%param), slice={[1:15:2], [5:10:1]} + ROOT %tuple = (f32[7,5]{1,0}) tuple(%slice) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "slice"), + op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ReduceWindowBackwardPass) { const char* const hlo_string = R"( HloModule module @@ -565,13 +735,15 @@ ENTRY conv { %rhs = f32[2,2,1]{2,1,0} parameter(1) %conv = f32[3,2,3]{2,1,0} convolution(%lhs, %rhs), window={size=1}, dim_labels=bf0_oi0->bf0 - ROOT %tuple = f32[3,2,3]{2,1,0} tuple(%conv) + ROOT %tuple = (f32[3,2,3]{2,1,0}) tuple(%conv) })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardingPropagation().Run(module.get())); - EXPECT_FALSE(changed); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{replicated}")); } TEST_F(ShardingPropagationTest, ConvolutionDifferentDimensionNumbers) { @@ -937,7 +1109,7 @@ ENTRY %conv { %p0_copy_0 = f32[8,256,128] copy(%param.0), sharding={devices=[1,4,1]0,1,2,3} %p1_copy_0 = f32[8,128,512] copy(%param.1), - sharding={devices=[1,2,2]0,1,2,3} + sharding={devices=[1,1,4]0,1,2,3} %p2_copy = f32[8,128] copy(%param.2) %dot_prop_rhs = f32[8,256,512] dot(%p0_copy_0, %p1_copy_0), lhs_batch_dims={0}, rhs_batch_dims={0}, @@ -966,16 +1138,18 @@ ENTRY %conv { ShardingPropagation().Run(module.get())); EXPECT_TRUE(changed); EXPECT_THAT(FindInstruction(module.get(), "dot_prop_rhs"), - op::Sharding("{devices=[1,2,2]0,1,2,3}")); + op::Sharding("{devices=[1,1,4]0,1,2,3}")); EXPECT_THAT(FindInstruction(module.get(), "dot_prop_lhs"), - op::Sharding("{devices=[1,2,2]0,1,2,3}")); + op::Sharding("{devices=[1,4,1]0,1,2,3}")); EXPECT_THAT(FindInstruction(module.get(), "dot_mat_vec"), op::Sharding("{devices=[1,4]0,1,2,3}")); - EXPECT_THAT(FindInstruction(module.get(), "p0_copy_1"), - op::Sharding("{replicated}")); - EXPECT_THAT(FindInstruction(module.get(), "p1_copy_1"), - op::Sharding("{devices=[1,2,2]0,1,2,3}")); + EXPECT_THAT( + FindInstruction(module.get(), "p0_copy_1"), + op::Sharding("{devices=[1,2,1,2]0,1,2,3 last_tile_dim_replicate}")); + EXPECT_THAT( + FindInstruction(module.get(), "p1_copy_1"), + op::Sharding("{devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate}")); EXPECT_THAT(FindInstruction(module.get(), "dot_back_prop_rhs"), op::Sharding("{devices=[1,2,2]0,1,2,3}")); } @@ -1004,6 +1178,146 @@ ENTRY %conv { op::Sharding("{devices=[2,2,1]0,1,2,3}")); } +TEST_F(ShardingPropagationTest, DotMergeOperands) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %p0 = f32[8,256,512] parameter(0), + sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %p1 = f32[8,128,512] parameter(1), + sharding={devices=[2,2,1,2]0,2,1,3,4,6,5,7 last_tile_dim_replicate} + %dot = f32[8,256,128] dot(%p0, %p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2} + ROOT %copy = f32[8,256,128] copy(%dot) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + 
ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "dot"), + op::Sharding("{devices=[2,2,2]0,1,2,3,4,5,6,7}")); +} + +TEST_F(ShardingPropagationTest, DotMergeOperands2) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %p0 = f32[8,256,512] parameter(0), sharding={devices=[2,2,2]0,1,2,3,4,5,6,7} + %p1 = f32[8,128,512] parameter(1), sharding={devices=[2,2,2]0,1,2,3,4,5,6,7} + %dot = f32[8,256,128] dot(%p0, %p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2} + ROOT %copy = f32[8,256,128] copy(%dot) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "dot"), + op::Sharding( + "{devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, BackwardDotFromContracting) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %p0 = f32[8,256,512] parameter(0), sharding={devices=[2,2,2]0,1,2,3,4,5,6,7} + %p1 = f32[8,128,512] parameter(1) + %copy1 = f32[8,128,512] copy(%p1) + %dot = f32[8,256,128] dot(%p0, %copy1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy = f32[8,256,128] copy(%dot) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "copy1"), + op::Sharding("{devices=[2,2,2]0,1,2,3,4,5,6,7}")); +} + +TEST_F(ShardingPropagationTest, ConvAsDotOnTrivialDims) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %lhs = f32[128,1,1,1001] parameter(0), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[1,1,1024,1001] parameter(1), sharding={devices=[1,2,1,1]0,1} + %convolution = f32[128,1,1,1024] convolution(%lhs, %rhs), + window={size=1x1 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f + ROOT %copy = f32[128,1,1,1024] copy(%convolution) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "convolution"), + op::Sharding("{devices=[1,1,2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, ConvAsDotOnTrivialDimsBackward) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %p0 = f32[128,5,5,128] parameter(0) + %lhs = f32[128,5,5,128] copy(%p0) + %p1 = f32[5,5,128,768] parameter(1) + %rhs = f32[5,5,128,768] copy(%p1) + %convolution = f32[128,1,1,768] convolution(%lhs, %rhs), window={size=5x5}, + dim_labels=b01f_01io->b01f, sharding={devices=[1,2,1,1]0,1} + ROOT %copy = f32[128,1,1,768] copy(%convolution) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "lhs"), + op::Sharding("{replicated}")); + EXPECT_THAT(FindInstruction(module.get(), 
"rhs"), + op::Sharding("{replicated}")); +} + +TEST_F(ShardingPropagationTest, + ConvolutionFilterIFOFPartitionedInputPartialReplicate) { + const char* const hlo_string = R"( + HloModule module + +ENTRY entry { + %lhs = f32[128,112,112,12] parameter(0) + %lhs.copy = f32[128,112,112,12] copy(f32[128,112,112,12] %lhs), + sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate} + %rhs = f32[7,7,12,64] parameter(1) + %rhs.copy = f32[7,7,12,64] copy(f32[7,7,12,64] %rhs), + sharding={devices=[1,1,2,2]0,1,2,3} + %conv = f32[128,56,56,64] convolution( + f32[128,112,112,12] %lhs.copy, + f32[7,7,12,64] %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=b01f_01io->b01f + ROOT %copy = f32[128,56,56,64] copy(conv) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + VLOG(1) << module->ToString(); + + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ConcatFromUserUnshardedDim) { const char* const hlo_string = R"( HloModule module @@ -1155,15 +1469,15 @@ ENTRY entry { EXPECT_THAT(FindInstruction(module.get(), "ttr"), op::Sharding("{devices=[2,1]0,1}")); EXPECT_THAT(FindInstruction(module.get(), "tr"), - op::Sharding("{{devices=[2,1]0,1}}")); + op::Sharding("{{devices=[1,3]0,1,2}}")); EXPECT_THAT(FindInstruction(module.get(), "fp"), op::Sharding("{{devices=[1,3]0,1,2}}")); EXPECT_THAT(FindInstruction(module.get(), "fgte"), op::Sharding("{devices=[1,3]0,1,2}")); EXPECT_THAT(FindInstruction(module.get(), "fr"), - op::Sharding("{{devices=[2,1]0,1}}")); + op::Sharding("{{devices=[1,3]0,1,2}}")); EXPECT_THAT(FindInstruction(module.get(), "conditional"), - op::Sharding("{{devices=[2,1]0,1}}")); + op::Sharding("{{devices=[1,3]0,1,2}}")); } TEST_F(ShardingPropagationTest, TupleFromUser) { @@ -1515,6 +1829,28 @@ ENTRY entry { op::Sharding("{devices=[2,1]0,1}")); } +TEST_F(ShardingPropagationTest, GatherFromIndex_PartialReplicate) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %indices = s32[3] parameter(1), + sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate} + %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9} + ROOT %copy = f32[3,9] copy(%gather) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "gather"), + op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, GatherFromDataOperand) { const char* hlo_string = R"( HloModule module @@ -1536,6 +1872,28 @@ ENTRY entry { op::Sharding("{devices=[1,2]0,1}")); } +TEST_F(ShardingPropagationTest, GatherFromDataOperand_PartialReplicate) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %indices = s32[3] parameter(1), sharding={replicated} + %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9} + ROOT %copy = f32[3,9] copy(%gather) +})"; + 
TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "gather"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, GatherToIndex) { const char* hlo_string = R"( HloModule module @@ -1557,6 +1915,98 @@ ENTRY entry { op::Sharding("{devices=[2]0,1}")); } +TEST_F(ShardingPropagationTest, GatherToIndex_PartialReplicate) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %p1 = s32[3] parameter(1) + %indices = s32[3] copy(%p1) + ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9}, + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "indices"), + op::Sharding("{devices=[2,2]0,1,2,3 last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, GatherToIndex2) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = bf16[2,4819,4] parameter(0), sharding={replicated} + %p1 = s32[2,1000,2] parameter(1) + %indices = s32[2,1000,2] copy(%p1) + ROOT %gather = bf16[2,1000,4] + gather(bf16[2,4819,4] %input, s32[2,1000,2] %indices), + offset_dims={2}, collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, slice_sizes={1,1,4}, + sharding={devices=[1,2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "indices"), + op::Sharding("{devices=[1,2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, GatherToIndex2_PartialReplicate) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = bf16[2,4819,4] parameter(0), sharding={replicated} + %p1 = s32[2,1000,2] parameter(1) + %indices = s32[2,1000,2] copy(%p1) + ROOT %gather = bf16[2,1000,4] + gather(bf16[2,4819,4] %input, s32[2,1000,2] %indices), + offset_dims={2}, collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, slice_sizes={1,1,4}, + sharding={devices=[1,2,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "indices"), + op::Sharding("{devices=[1,2,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, GatherToIndex3) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = bf16[2,4819,4] parameter(0), sharding={replicated} + %p1 = s32[2,2,1000] parameter(1) + %indices = s32[2,2,1000] copy(%p1) + ROOT %gather = bf16[2,1000,4] + gather(bf16[2,4819,4] %input, s32[2,2,1000] %indices), + offset_dims={2}, collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1,4}, + sharding={devices=[1,2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool 
changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "indices"), + op::Sharding("{devices=[1,1,2]0,1}")); +} + TEST_F(ShardingPropagationTest, GatherToDataOperand) { const char* hlo_string = R"( HloModule module @@ -1578,6 +2028,27 @@ ENTRY entry { op::Sharding("{devices=[1,2]0,1}")); } +TEST_F(ShardingPropagationTest, GatherToDataOperand_PartialReplicate) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %p0 = f32[2,9] parameter(0) + %input = f32[2,9] copy(%p0) + %indices = s32[3] parameter(1), sharding={replicated} + ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9}, sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "input"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, DataOperandToScatter) { const char* const hlo_string = R"( HloModule module @@ -1609,6 +2080,38 @@ ENTRY entry { op::Sharding("{devices=[1,2]0,1}")); } +TEST_F(ShardingPropagationTest, DataOperandToScatter_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), sharding={replicated} + %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + ROOT %copy = f32[2,9] copy(%scatter) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "scatter"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, UpdateOperandToScatter) { const char* const hlo_string = R"( HloModule module @@ -1640,6 +2143,70 @@ ENTRY entry { op::Sharding("{devices=[1,2]0,1}")); } +TEST_F(ShardingPropagationTest, UpdateOperandToScatter_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + ROOT %copy = f32[2,9] copy(%scatter) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + 
EXPECT_THAT(FindInstruction(module.get(), "scatter"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, ScatterToDataOperand_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %p0 = f32[2,9] parameter(0) + %input = f32[2,9] copy(%p0) + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), sharding={replicated} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "input"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ScatterToDataOperand) { const char* const hlo_string = R"( HloModule module @@ -1671,6 +2238,38 @@ ENTRY entry { op::Sharding("{devices=[1,2]0,1}")); } +TEST_F(ShardingPropagationTest, ScatterToUpdateOperand_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0) + %indices = s32[3] parameter(1), sharding={replicated} + %p2 = f32[3,9] parameter(2) + %updates = f32[3,9] copy(%p2) + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "updates"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ScatterToUpdateOperand) { const char* const hlo_string = R"( HloModule module @@ -1733,6 +2332,38 @@ ENTRY entry { op::Sharding("{devices=[2]0,1}")); } +TEST_F(ShardingPropagationTest, ScatterUpdateToIndex_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %p1 = s32[3] parameter(1), sharding={replicated} + %indices = s32[3] copy(%p1) + %updates = f32[3,9] parameter(2), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "indices"), + 
op::Sharding("{devices=[2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ScatterIndexToUpdate) { const char* const hlo_string = R"( HloModule module @@ -1764,5 +2395,130 @@ ENTRY entry { op::Sharding("{devices=[2,1]0,1}")); } +TEST_F(ShardingPropagationTest, ScatterIndexToUpdate_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %indices = s32[3] parameter(1), + sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate} + %p2 = f32[3,9] parameter(2), sharding={replicated} + %updates = f32[3,9] copy(%p2) + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "updates"), + op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, PartialShardingOnElementwise) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %p0 = f32[2,9] parameter(0), sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %p1 = f32[2,9] parameter(1), sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate} + %lhs = f32[2,9] copy(%p0) + %rhs = f32[2,9] copy(%p1) + %add = f32[2,9] add(%lhs, %rhs) + ROOT %copy = f32[2,9] copy(%add) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "lhs"), + op::Sharding("{devices=[2,2]0,2,1,3}")); + EXPECT_THAT(FindInstruction(module.get(), "rhs"), + op::Sharding("{devices=[2,2]0,2,1,3}")); + EXPECT_THAT(FindInstruction(module.get(), "add"), + op::Sharding("{devices=[2,2]0,2,1,3}")); +} + +TEST_F(ShardingPropagationTest, PartialShardingOnElementwise2) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %p0 = f32[2,9] parameter(0), sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %p1 = f32[2,9] parameter(1), sharding={devices=[2,1,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate} + %lhs = f32[2,9] copy(%p0) + %rhs = f32[2,9] copy(%p1) + %add = f32[2,9] add(%lhs, %rhs) + ROOT %copy = f32[2,9] copy(%add) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "lhs"), + op::Sharding("{devices=[2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); + EXPECT_THAT( + FindInstruction(module.get(), "rhs"), + op::Sharding("{devices=[2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); + EXPECT_THAT( + FindInstruction(module.get(), "add"), + op::Sharding("{devices=[2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, PartialShardingTransposeForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %transpose { + %param = f32[7,11,13]{2,1,0} parameter(0), + 
sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %transpose = f32[11,13,7]{2,1,0} transpose(%param), dimensions={1,2,0} + ROOT %copy = f32[11,13,7]{2,1,0} copy(%transpose) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "transpose"), + op::Sharding( + "{devices=[1,2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, PartialShardingTransposeBackwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %transpose { + %param = f32[7,11,13]{2,1,0} parameter(0) + %copy = f32[7,11,13]{2,1,0} copy(%param) + ROOT %transpose = f32[11,13,7]{2,1,0} transpose(%copy), dimensions={1,2,0}, + sharding={devices=[1,2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "copy"), + op::Sharding( + "{devices=[2,1,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD index ce19934bb88..5fd7b7850cf 100644 --- a/tensorflow/compiler/xla/service/spmd/BUILD +++ b/tensorflow/compiler/xla/service/spmd/BUILD @@ -23,6 +23,7 @@ cc_library( "spmd_partitioner_util.cc", ], hdrs = [ + "convolution_handler.h", "spmd_partitioner.h", "spmd_partitioner_util.h", ], @@ -48,6 +49,7 @@ cc_library( "//tensorflow/compiler/xla/service:pattern_matcher", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/core:lib", "//tensorflow/core/platform:numbers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -73,3 +75,16 @@ tf_cc_test( "//tensorflow/core:test", ], ) + +cc_library( + name = "schedule_aware_all_gather_cse", + srcs = ["schedule_aware_all_gather_cse.cc"], + hdrs = ["schedule_aware_all_gather_cse.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/container:flat_hash_map", + ], +) diff --git a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc index 01d7ea2ff14..81419c55109 100644 --- a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc +++ b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/service/spmd/convolution_handler.h" + #include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dot_as_convolution_util.h" @@ -32,24 +34,32 @@ limitations under the License. namespace xla { namespace spmd { + namespace { -// Partition convolution. 
-StatusOr<HloInstruction*> PartitionConvolution(
-    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
-    const HloSharding& output_sharding, const Window& conv_window,
-    HloInstruction* original_hlo, int64 num_partitions,
-    const SpmdPartitionerOptions& options, HloInstruction* partition_id,
-    HloModule* module, SpmdBuilder* b);
-
-// Partition convolution with only paralell dims are tiled
-StatusOr<HloInstruction*> PartitionConvolutionWithParallelDimension(
+// Partition convolution with batch group count.
+StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
     const HloSharding& output_sharding, const Window& conv_window,
     HloInstruction* original_hlo, int64 num_partitions, SpmdBuilder* b) {
   TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
+  if (original_hlo->batch_group_count() == 1 ||
+      original_hlo->batch_group_count() < num_partitions) {
+    return nullptr;
+  }
   const auto& dnums = original_hlo->convolution_dimension_numbers();
+  // Only supports the case where batch_group_size equals input_batch_size.
+  const int64 input_batch_size =
+      lhs.base_shape().dimensions(dnums.input_batch_dimension());
+  const int64 kernel_output_feature_size =
+      rhs.base_shape().dimensions(dnums.kernel_output_feature_dimension());
+  if (input_batch_size != kernel_output_feature_size ||
+      original_hlo->batch_group_count() != input_batch_size) {
+    return nullptr;
+  }
+
+  // Map RHS indices to LHS indices.
   std::vector<int64> rhs_to_lhs_indices(output_base_shape.rank());
   rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
       dnums.input_batch_dimension();
@@ -59,73 +69,167 @@ StatusOr<HloInstruction*> PartitionConvolutionWithParallelDimension(
     rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
         dnums.input_spatial_dimensions(i);
   }
+
+  // Map LHS indices to RHS indices.
   std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank());
   for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
     lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
   }
+
+  // Map LHS indices to output indices.
+  std::vector<int64> lhs_to_output_indices(lhs.base_shape().rank(), -1);
+  lhs_to_output_indices[dnums.input_batch_dimension()] =
+      dnums.output_feature_dimension();
+  lhs_to_output_indices[dnums.input_feature_dimension()] =
+      dnums.output_batch_dimension();
+  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
+    lhs_to_output_indices[dnums.input_spatial_dimensions(i)] =
+        dnums.output_spatial_dimensions(i);
+  }
+
+  // Align LHS or RHS to other operand if input batch dim or kernel output
+  // feature dim is partitioned.
   auto aligned_rhs_sharding =
       hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
   auto aligned_lhs_sharding =
       hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);
-  // Handling cases where all the partitioned dimensions are parallel
-  // dimensions.
-  int64 lhs_parallel_dim_partitions = 1;
-  int64 rhs_parallel_dim_partitions = 1;
-  std::vector<int64> parallel_spatial_dims;
-  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
-    int64 lhs_dim = dnums.input_spatial_dimensions(i);
-    int64 lhs_size = lhs.base_shape().dimensions(lhs_dim);
-    const auto& wd = conv_window.dimensions(i);
-    int64 rhs_dim = dnums.kernel_spatial_dimensions(i);
-    if (dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) {
-      parallel_spatial_dims.emplace_back(i);
-      lhs_parallel_dim_partitions *= ShardCountAtDim(lhs.sharding(), lhs_dim);
-      rhs_parallel_dim_partitions *= ShardCountAtDim(rhs.sharding(), rhs_dim);
-    }
-  }
-  bool lhs_partition_dims_are_parallel =
-      (lhs_parallel_dim_partitions == num_partitions);
-  bool rhs_partition_dims_are_parallel =
-      (rhs_parallel_dim_partitions == num_partitions);
-
-  // If there is a parallel dim and all the partitioned dimensions are parallel
-  // dimensions in either LHS or RHS, simply create partitioned convolutions.
-  if (parallel_spatial_dims.empty() || ((!lhs_partition_dims_are_parallel) &&
-                                        (!rhs_partition_dims_are_parallel))) {
+  bool lhs_batch_dim_is_partitioned =
+      (ShardCountAtDim(lhs.sharding(), dnums.input_batch_dimension()) ==
+       num_partitions);
+  bool rhs_output_feature_dim_is_partitioned =
+      (ShardCountAtDim(rhs.sharding(),
+                       dnums.kernel_output_feature_dimension()) ==
+       num_partitions);
+  if (!lhs_batch_dim_is_partitioned && !rhs_output_feature_dim_is_partitioned) {
     return nullptr;
   }
-  // Reshard LHS or RHS to partition at parallel dimensions as the other
-  // operand.
-  if (lhs_partition_dims_are_parallel) {
+  // Reshard LHS or RHS to partition at batch dimension or output feature
+  // dimension as the other operand.
+  if (lhs_batch_dim_is_partitioned) {
+    rhs = rhs.Reshard(aligned_rhs_sharding);
+  } else {
+    lhs = lhs.Reshard(aligned_lhs_sharding);
+  }
+  // Align output sharding after LHS and RHS sharding are consistent.
+  auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
+      lhs.sharding(), lhs_to_output_indices);
+
+  // Get LHS and RHS sharded shape.
+  auto lhs_shard_shape = MakePartitionedShape(lhs.base_shape(), lhs.sharding());
+  auto rhs_shard_shape = MakePartitionedShape(rhs.base_shape(), rhs.sharding());
+  const int64 batch_group_count =
+      CeilOfRatio(original_hlo->batch_group_count(), num_partitions);
+  // Create partitioned convolution.
+  TF_ASSIGN_OR_RETURN(
+      Shape sharded_conv_shape,
+      ShapeInference::InferConvolveShape(
+          lhs_shard_shape, rhs_shard_shape, original_hlo->feature_group_count(),
+          batch_group_count, conv_window, dnums));
+  auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
+      sharded_conv_shape, lhs.hlo(), rhs.hlo(),
+      original_hlo->feature_group_count(), batch_group_count, conv_window,
+      dnums, original_hlo->precision_config()));
+  sharded_conv->set_sharding(aligned_output_sharding);
+  return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
+      .Reshard(output_sharding)
+      .hlo();
+}
+
+// Partition convolution with feature group count.
+StatusOr PartitionConvolutionWithFeatureGroupCount( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, const Window& conv_window, + HloInstruction* original_hlo, int64 num_partitions, SpmdBuilder* b) { + TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); + if (original_hlo->feature_group_count() == 1 || + original_hlo->feature_group_count() < num_partitions) { + return nullptr; + } + + const auto& dnums = original_hlo->convolution_dimension_numbers(); + const int64 input_feature_size = + lhs.base_shape().dimensions(dnums.input_feature_dimension()); + const int64 kernel_output_feature_size = + rhs.base_shape().dimensions(dnums.kernel_output_feature_dimension()); + if (input_feature_size != kernel_output_feature_size || + input_feature_size % original_hlo->feature_group_count() != 0) { + return nullptr; + } + + // Align RHS indices to LHS. + std::vector rhs_to_lhs_indices(output_base_shape.rank()); + rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = + dnums.input_feature_dimension(); + rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = + dnums.input_batch_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = + dnums.input_spatial_dimensions(i); + } + + // Align LHS indices to RHS. + std::vector lhs_to_rhs_indices(output_base_shape.rank()); + for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { + lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; + } + + // Align LHS indices to output. + std::vector lhs_to_output_indices(output_base_shape.rank()); + lhs_to_output_indices[dnums.input_feature_dimension()] = + dnums.output_feature_dimension(); + lhs_to_output_indices[dnums.input_batch_dimension()] = + dnums.output_batch_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + lhs_to_output_indices[dnums.input_spatial_dimensions(i)] = + dnums.output_spatial_dimensions(i); + } + + // Align LHS or RHS if input_feature_dim or kernel_output_feature_dim is + // partitioned. + auto aligned_rhs_sharding = + hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); + auto aligned_lhs_sharding = + hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); + + bool lhs_feature_dim_is_partitioned = + (ShardCountAtDim(lhs.sharding(), dnums.input_feature_dimension()) == + num_partitions); + bool rhs_output_feature_dim_is_partitioned = + (ShardCountAtDim(rhs.sharding(), + dnums.kernel_output_feature_dimension()) == + num_partitions); + if (!lhs_feature_dim_is_partitioned && + !rhs_output_feature_dim_is_partitioned) { + return nullptr; + } + // Reshard LHS or RHS to partition at input feature dimension or output + // feature dimension as the other operand. + if (lhs_feature_dim_is_partitioned) { rhs = rhs.Reshard(aligned_rhs_sharding); } else { lhs = lhs.Reshard(aligned_lhs_sharding); } - // Get LHS and RHS sharded shape. + // Align output sharding after LHS and RHS sharding are consistent. + auto aligned_output_sharding = hlo_sharding_util::TransposeSharding( + lhs.sharding(), lhs_to_output_indices); + auto lhs_shard_shape = MakePartitionedShape(lhs.base_shape(), lhs.sharding()); auto rhs_shard_shape = MakePartitionedShape(rhs.base_shape(), rhs.sharding()); + int64 feature_group_count = + CeilOfRatio(original_hlo->feature_group_count(), num_partitions); - // Update convolution window. 
- auto new_window = conv_window; - for (const auto& spatial_dim : parallel_spatial_dims) { - auto wd = new_window.mutable_dimensions(spatial_dim); - wd->set_size(lhs_shard_shape.dimensions( - dnums.input_spatial_dimensions(spatial_dim))); - wd->set_stride(std::max(1, wd->size() - 1)); - wd->set_base_dilation(wd->size()); - } TF_ASSIGN_OR_RETURN( Shape sharded_conv_shape, ShapeInference::InferConvolveShape( - lhs_shard_shape, rhs_shard_shape, original_hlo->feature_group_count(), - original_hlo->batch_group_count(), new_window, dnums)); + lhs_shard_shape, rhs_shard_shape, feature_group_count, + original_hlo->batch_group_count(), conv_window, dnums)); auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve( - sharded_conv_shape, lhs.hlo(), rhs.hlo(), - original_hlo->feature_group_count(), original_hlo->batch_group_count(), - new_window, dnums, original_hlo->precision_config())); - sharded_conv->set_sharding(original_hlo->sharding()); + sharded_conv_shape, lhs.hlo(), rhs.hlo(), feature_group_count, + original_hlo->batch_group_count(), conv_window, dnums, + original_hlo->precision_config())); + sharded_conv->set_sharding(aligned_output_sharding); return PartitionedHlo(sharded_conv, output_base_shape, lhs.state()) .Reshard(output_sharding) .hlo(); @@ -214,7 +318,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS( int64 lhs_dimension = dnums.input_spatial_dimensions(i); int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension); - auto wd = conv_window.dimensions(i); + const auto& wd = conv_window.dimensions(i); if (wd.base_dilation() != 1 || wd.window_reversal()) { return nullptr; } @@ -260,7 +364,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS( // Calculate the left and right halo sizes as described in the comments // above. It calculcates the halo sizes with dilation, so we apply // CeilOfRatio({left,right}_halo_size, window_dilation). - auto wd = conv_window.dimensions(i); + const auto& wd = conv_window.dimensions(i); int64 padding_low = wd.padding_low(); int64 padding_high = wd.padding_high(); int64 base = lhs.base_shape().dimensions(lhs_dimension); @@ -430,7 +534,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; } - Window window = conv_window; + const Window& window = conv_window; std::vector reversed_rhs_dims; for (int64 i = 0; i < window.dimensions_size(); ++i) { if (window.dimensions(i).window_reversal()) { @@ -505,7 +609,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( int64 lhs_dimension = dnums.input_spatial_dimensions(i); int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension); - auto wd = window.dimensions(i); + const auto& wd = window.dimensions(i); if (wd.base_dilation() != 1) { // TODO(wangtao): support parallel dim if it is replicate here. return nullptr; @@ -540,7 +644,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( // Calculate the left and right halo sizes as described in the comments // above. 
- auto wd = window.dimensions(i); + const auto& wd = window.dimensions(i); int64 padding_low = wd.padding_low(); int64 padding_high = wd.padding_high(); int64 base = lhs.base_shape().dimensions(lhs_dimension); @@ -692,116 +796,6 @@ StatusOr PartitionConvolutionTiledOutput( shard_shape.dimensions())); } -StatusOr PartitionConvolutionGroupOnParallelDim( - PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const Window& conv_window, - HloInstruction* original_hlo, const ConvolutionDimsMapping& dims_mapping, - int64 num_partitions, const SpmdPartitionerOptions& options, - HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) { - std::vector lhs_dims; - std::vector rhs_dims; - std::vector output_dims; - auto lhs_sharding_dims_adjusted_to_output = - lhs.sharding().IsReplicated() - ? std::vector(lhs.base_shape().rank(), 1) - : lhs.sharding().tile_assignment().dimensions(); - auto rhs_sharding_dims_adjusted_to_output = - rhs.sharding().IsReplicated() - ? std::vector(rhs.base_shape().rank(), 1) - : rhs.sharding().tile_assignment().dimensions(); - auto output_sharding_dims_adjusted_to_lhs = - output_sharding.tile_assignment().dimensions(); - bool lhs_rhs_dims_matching = true; - for (const auto& dim : dims_mapping.parallel_spatial_dims) { - lhs_dims.push_back(dim.lhs); - rhs_dims.push_back(dim.rhs); - output_dims.push_back(dim.output); - if (lhs_sharding_dims_adjusted_to_output[dim.lhs] != - rhs_sharding_dims_adjusted_to_output[dim.rhs]) { - lhs_rhs_dims_matching = false; - } - lhs_sharding_dims_adjusted_to_output[dim.lhs] = - output_sharding.tile_assignment().dim(dim.output); - rhs_sharding_dims_adjusted_to_output[dim.rhs] = - output_sharding.tile_assignment().dim(dim.output); - output_sharding_dims_adjusted_to_lhs[dim.output] = - lhs.sharding().tile_assignment().dim(dim.lhs); - } - auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims); - auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims); - auto output_grouped = GroupShardingOnDims(output_sharding, output_dims); - if (lhs_rhs_dims_matching) { - if (ShapeUtil::ByteSizeOf(lhs.base_shape()) > - ShapeUtil::ByteSizeOf(rhs.base_shape())) { - rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped); - rhs = rhs.Reshard(UngroupSharding(rhs_grouped)); - } else { - lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped); - lhs = lhs.Reshard(UngroupSharding(lhs_grouped)); - } - auto reshaped_output_tiling = output_sharding.tile_assignment(); - reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs); - output_grouped = AlignGroupsWith( - GroupShardingOnDims(HloSharding::Tile(reshaped_output_tiling), - output_dims), - lhs_grouped); - } else { - auto reshaped_lhs_tiling = lhs.sharding().tile_assignment(); - reshaped_lhs_tiling.Reshape(lhs_sharding_dims_adjusted_to_output); - lhs_grouped = AlignGroupsWith( - GroupShardingOnDims(HloSharding::Tile(reshaped_lhs_tiling), lhs_dims), - output_grouped); - lhs = lhs.Reshard(UngroupSharding(lhs_grouped)); - auto reshaped_rhs_tiling = rhs.sharding().tile_assignment(); - reshaped_rhs_tiling.Reshape(rhs_sharding_dims_adjusted_to_output); - rhs_grouped = AlignGroupsWith( - GroupShardingOnDims(HloSharding::Tile(reshaped_rhs_tiling), rhs_dims), - output_grouped); - rhs = rhs.Reshard(UngroupSharding(rhs_grouped)); - } - - // Update LHS and RHS sharding and shape. 
- lhs.hlo()->set_sharding(lhs_grouped.sharding); - rhs.hlo()->set_sharding(rhs_grouped.sharding); - CHECK(lhs.hlo() != rhs.hlo() || lhs_grouped.sharding == rhs_grouped.sharding); - auto per_group_partitioner_state = CreatePerGroupPartitioningState( - lhs.state(), lhs_grouped.device_groups, b); - auto grouped_lhs_base_shape = - GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()); - auto grouped_lhs_shard_shape = - MakePartitionedShape(grouped_lhs_base_shape, lhs.sharding()); - // Update convolution window with the new shape - auto new_window = conv_window; - for (const auto& dim : dims_mapping.parallel_spatial_dims) { - auto wd = new_window.mutable_dimensions(dim.spatial); - wd->set_size(grouped_lhs_shard_shape.dimensions(dim.lhs)); - wd->set_stride(std::max(1, wd->size() - 1)); - wd->set_base_dilation(wd->size()); - } - - auto new_partition_id = - lhs.state().collective_ops_creator.create_partition_id(b); - TF_ASSIGN_OR_RETURN( - auto conv, - PartitionConvolution( - PartitionedHlo(lhs.hlo(), grouped_lhs_base_shape, - per_group_partitioner_state), - PartitionedHlo(rhs.hlo(), - GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()), - per_group_partitioner_state), - GetPerGroupBaseShape(output_grouped, output_base_shape), - output_grouped.sharding, new_window, original_hlo, - num_partitions / output_grouped.device_groups.size(), options, - new_partition_id, module, b)); - // Reset the LHS sharding to the ungrouped one. - lhs.hlo()->set_sharding(UngroupSharding(lhs_grouped)); - rhs.hlo()->set_sharding(UngroupSharding(rhs_grouped)); - conv->set_sharding(UngroupSharding(output_grouped)); - return PartitionedHlo(conv, output_base_shape, lhs.state()) - .Reshard(output_sharding) - .hlo(); -} - // Partition convolution with only one kind of dims partitioned. StatusOr PartitionConvolutionBaseCase( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, @@ -811,13 +805,26 @@ StatusOr PartitionConvolutionBaseCase( HloModule* module, SpmdBuilder* b) { TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); - // Case 1: Either RHS or LHS is only partitioned at parallel dimensions. - TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv, - PartitionConvolutionWithParallelDimension( - lhs, rhs, output_base_shape, output_sharding, - conv_window, original_hlo, num_partitions, b)); - if (parallel_partitioned_conv) { - return parallel_partitioned_conv; + // Case 1: Handle depthwise convolution with batch group count or + // feature group count. + if (original_hlo->batch_group_count() > 1) { + TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv, + PartitionConvolutionWithBatchGroupCount( + lhs, rhs, output_base_shape, output_sharding, + conv_window, original_hlo, num_partitions, b)); + if (parallel_partitioned_conv) { + return parallel_partitioned_conv; + } + } + + if (original_hlo->feature_group_count() > 1) { + TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv, + PartitionConvolutionWithFeatureGroupCount( + lhs, rhs, output_base_shape, output_sharding, + conv_window, original_hlo, num_partitions, b)); + if (parallel_partitioned_conv) { + return parallel_partitioned_conv; + } } // Case 2: both RHS and LHS are tiled. @@ -862,13 +869,15 @@ StatusOr PartitionConvolutionBaseCase( return nullptr; } +} // namespace + // Partition convolution. 
StatusOr PartitionConvolution( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const Window& conv_window, - HloInstruction* original_hlo, int64 num_partitions, - const SpmdPartitionerOptions& options, HloInstruction* partition_id, - HloModule* module, SpmdBuilder* b) { + const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, + const Window& conv_window, HloInstruction* original_hlo, + int64 num_partitions, const SpmdPartitionerOptions& options, + HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) { TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); TF_ASSIGN_OR_RETURN( @@ -880,133 +889,57 @@ StatusOr PartitionConvolution( return try_partitioned_conv; } - const auto& dnums = original_hlo->convolution_dimension_numbers(); - spmd::ConvolutionDimsMapping mapping; - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - int64 lhs_dim = dnums.input_spatial_dimensions(i); - int64 lhs_size = lhs.base_shape().dimensions(lhs_dim); - const auto& wd = original_hlo->window().dimensions(i); - int64 rhs_dim = dnums.kernel_spatial_dimensions(i); - int64 output_dim = dnums.output_spatial_dimensions(i); - if (dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) { - mapping.parallel_spatial_dims.emplace_back(); - mapping.parallel_spatial_dims.back().lhs = lhs_dim; - mapping.parallel_spatial_dims.back().rhs = rhs_dim; - mapping.parallel_spatial_dims.back().output = output_dim; - mapping.parallel_spatial_dims.back().spatial = i; - } else { - mapping.non_parallel_spatial_dims.emplace_back(); - mapping.non_parallel_spatial_dims.back().lhs = lhs_dim; - mapping.non_parallel_spatial_dims.back().rhs = rhs_dim; - mapping.non_parallel_spatial_dims.back().output = output_dim; - mapping.non_parallel_spatial_dims.back().spatial = i; - } - } - - // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output. - auto get_partitions_for_dims = - [&](const HloSharding& sharding, - absl::Span dims, - int lhs_rhs_or_output) { - int64 partitions = 1; - if (sharding.IsTileMaximal()) { - return partitions; - } - for (const auto& dim : dims) { - if (lhs_rhs_or_output == 0) { - partitions *= sharding.tile_assignment().dim(dim.lhs); - } else if (lhs_rhs_or_output == 1) { - partitions *= sharding.tile_assignment().dim(dim.rhs); - } else { - CHECK_EQ(lhs_rhs_or_output, 2); - partitions *= sharding.tile_assignment().dim(dim.output); - } - } - return partitions; - }; - - const int64 lhs_parallel_spatial_partitions = - get_partitions_for_dims(lhs.sharding(), mapping.parallel_spatial_dims, 0); - const int64 rhs_parallel_spatial_partitions = - get_partitions_for_dims(rhs.sharding(), mapping.parallel_spatial_dims, 1); - const int64 output_parallel_spatial_partitions = get_partitions_for_dims( - original_hlo->sharding(), mapping.parallel_spatial_dims, 2); - - // Recursively partition on different types of dimensions. - // - // Case 1: Group partitions by parallel spatial dims. 
- if (lhs_parallel_spatial_partitions == rhs_parallel_spatial_partitions && - lhs_parallel_spatial_partitions == output_parallel_spatial_partitions && - lhs_parallel_spatial_partitions > 1) { - TF_ASSIGN_OR_RETURN(auto try_partitioned_conv, - PartitionConvolutionGroupOnParallelDim( - lhs, rhs, output_base_shape, output_sharding, - conv_window, original_hlo, mapping, num_partitions, - options, partition_id, module, b)); - if (try_partitioned_conv) { - return try_partitioned_conv; - } - } - return nullptr; } -} // namespace - Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { - auto dot_dnums = dot_as_convolution_util::ParseDotGeneralFromConvolution(hlo); - if (dot_dnums) { - // Use HandleDotHelper() for convs that are actually einsums. - spmd::DotGeneralDimsMapping mapping; - for (const auto& dims : dot_dnums->batch_dims) { - mapping.batch_dims.emplace_back(); - mapping.batch_dims.back().lhs = dims.lhs; - mapping.batch_dims.back().rhs = dims.rhs; - mapping.batch_dims.back().output = dims.output; - } - for (const auto& dims : dot_dnums->contracting_dims) { - mapping.contracting_dims.emplace_back(); - mapping.contracting_dims.back().lhs = dims.lhs; - mapping.contracting_dims.back().rhs = dims.rhs; - mapping.contracting_dims.back().output = dims.output; - } - for (const auto& dims : dot_dnums->lhs_non_contracting_dims) { - mapping.lhs_non_contracting_dims.emplace_back(); - mapping.lhs_non_contracting_dims.back().lhs = dims.lhs; - mapping.lhs_non_contracting_dims.back().rhs = dims.rhs; - mapping.lhs_non_contracting_dims.back().output = dims.output; - } - for (const auto& dims : dot_dnums->rhs_non_contracting_dims) { - mapping.rhs_non_contracting_dims.emplace_back(); - mapping.rhs_non_contracting_dims.back().lhs = dims.lhs; - mapping.rhs_non_contracting_dims.back().rhs = dims.rhs; - mapping.rhs_non_contracting_dims.back().output = dims.output; - } - auto create_sharded_conv = - [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo, - spmd::SpmdBuilder* b) -> StatusOr { - TF_ASSIGN_OR_RETURN( - auto sharded_conv, - dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution( - *hlo, *dot_dnums, lhs_hlo, rhs_hlo)); - return b->AddInstruction(std::move(sharded_conv)); - }; - return HandleDotHelper(hlo, mapping, create_sharded_conv); + auto dims_info = dot_as_convolution_util::ParseConvolutionDimsInfo(hlo); + spmd::DotConvDimsMapping mapping; + for (const auto& dims : dims_info.batch_dims) { + mapping.batch_dims.emplace_back(); + mapping.batch_dims.back().lhs = dims.lhs; + mapping.batch_dims.back().rhs = dims.rhs; + mapping.batch_dims.back().output = dims.output; + mapping.batch_dims.back().spatial = dims.spatial_dim; } - - auto lhs = GetPartitionedHlo(hlo->operand(0)); - auto rhs = GetPartitionedHlo(hlo->operand(1)); - TF_ASSIGN_OR_RETURN( - auto partitioned_conv, - PartitionConvolution(lhs, rhs, hlo->shape(), hlo->sharding(), - hlo->window(), hlo, num_partitions_, options_, - partition_id_, module_, &b_)); - - if (partitioned_conv) { - SetPartitionedHlo(hlo, [&] { return partitioned_conv; }); - return Status::OK(); + for (const auto& dims : dims_info.contracting_dims) { + mapping.contracting_dims.emplace_back(); + mapping.contracting_dims.back().lhs = dims.lhs; + mapping.contracting_dims.back().rhs = dims.rhs; + mapping.contracting_dims.back().output = dims.output; + mapping.contracting_dims.back().spatial = dims.spatial_dim; } - return DefaultAction(hlo); + for (const auto& dims : dims_info.lhs_non_contracting_dims) { + mapping.lhs_non_contracting_dims.emplace_back(); 
+ mapping.lhs_non_contracting_dims.back().lhs = dims.lhs; + mapping.lhs_non_contracting_dims.back().rhs = dims.rhs; + mapping.lhs_non_contracting_dims.back().output = dims.output; + mapping.lhs_non_contracting_dims.back().spatial = dims.spatial_dim; + } + for (const auto& dims : dims_info.rhs_non_contracting_dims) { + mapping.rhs_non_contracting_dims.emplace_back(); + mapping.rhs_non_contracting_dims.back().lhs = dims.lhs; + mapping.rhs_non_contracting_dims.back().rhs = dims.rhs; + mapping.rhs_non_contracting_dims.back().output = dims.output; + mapping.rhs_non_contracting_dims.back().spatial = dims.spatial_dim; + } + for (const auto& dims : dims_info.conv_spatial_dims) { + mapping.conv_spatial_dims.emplace_back(); + mapping.conv_spatial_dims.back().lhs = dims.lhs; + mapping.conv_spatial_dims.back().rhs = dims.rhs; + mapping.conv_spatial_dims.back().output = dims.output; + mapping.conv_spatial_dims.back().spatial = dims.spatial_dim; + } + auto create_sharded_conv = + [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo, + spmd::SpmdBuilder* b) -> StatusOr { + TF_ASSIGN_OR_RETURN( + auto sharded_conv, + dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution( + *hlo, dims_info, lhs_hlo, rhs_hlo)); + return b->AddInstruction(std::move(sharded_conv)); + }; + return HandleDotHelper(hlo, mapping, create_sharded_conv); } } // namespace spmd diff --git a/tensorflow/compiler/xla/service/spmd/convolution_handler.h b/tensorflow/compiler/xla/service/spmd/convolution_handler.h new file mode 100644 index 00000000000..dced14a4872 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/convolution_handler.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_CONVOLUTION_HANDLER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_CONVOLUTION_HANDLER_H_ + +#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" + +namespace xla { +namespace spmd { + +// Partition convolution. 
+StatusOr<HloInstruction*> PartitionConvolution(
+    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
+    const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
+    const Window& conv_window, HloInstruction* original_hlo,
+    int64 num_partitions, const SpmdPartitionerOptions& options,
+    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b);
+
+}  // namespace spmd
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_CONVOLUTION_HANDLER_H_
diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
index 55ebe120d01..25c21ba60f2 100644
--- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
@@ -24,18 +24,20 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
 #include "tensorflow/compiler/xla/service/shape_inference.h"
+#include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"
 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/platform/numbers.h"
 
 namespace xla {
 namespace spmd {
 
 Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
-  DotGeneralDimsMapping mapping;
+  DotConvDimsMapping mapping;
   const auto& dnums = hlo->dot_dimension_numbers();
   int64 next_output_dim = 0;
   for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) {
@@ -87,8 +89,8 @@ namespace {
 
 StatusOr<HloInstruction*> PartitionBaseCase(
     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
-    const HloSharding& output_sharding,
-    const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
+    const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
+    int64 num_partitions,
     const std::function<StatusOr<HloInstruction*>(
         HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
     HloModule* module, HloInstruction* original_hlo, int64 lhs_batch_partitions,
@@ -97,11 +99,17 @@ StatusOr<HloInstruction*> PartitionBaseCase(
     int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions,
    int64 output_lhs_non_contracting_partitions,
    int64 output_rhs_non_contracting_partitions,
-    int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
+    const SpmdPartitionerOptions& options, SpmdBuilder* b,
     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
-        windowed_dot_general_loops) {
+        windowed_dot_general_loops,
+    bool may_reshard_without_detecting_match) {
   const HloSharding& lhs_sharding = lhs.sharding();
   const HloSharding& rhs_sharding = rhs.sharding();
+  if (lhs_sharding.ReplicateOnLastTileDim() ||
+      rhs_sharding.ReplicateOnLastTileDim() ||
+      output_sharding.ReplicateOnLastTileDim()) {
+    return nullptr;
+  }
   std::vector<int64> lhs_to_rhs_indices(lhs.base_shape().rank(), -1);
   std::vector<int64> lhs_to_output_indices(lhs.base_shape().rank(), -1);
   std::vector<int64> rhs_to_lhs_indices(rhs.base_shape().rank(), -1);
@@ -109,7 +117,7 @@ StatusOr<HloInstruction*> PartitionBaseCase(
   std::vector<int64> output_to_lhs_indices(output_base_shape.rank(), -1);
   std::vector<int64> output_to_rhs_indices(output_base_shape.rank(), -1);
   auto populate_indices_mapping =
-      [&](const DotGeneralDimsMapping::DimsMapping& mapping) {
+      [&](const DotConvDimsMapping::DimsMapping& mapping) {
        if (mapping.lhs >= 0) {
          lhs_to_rhs_indices[mapping.lhs] = mapping.rhs;
lhs_to_output_indices[mapping.lhs] = mapping.output; @@ -135,24 +143,27 @@ StatusOr PartitionBaseCase( for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) { populate_indices_mapping(mapping); } + for (const auto& mapping : dims_mapping.conv_spatial_dims) { + populate_indices_mapping(mapping); + } auto lhs_sharding_transposed_to_match_rhs = - TransposeShardingWithCollapsedDims(lhs_sharding, lhs_to_rhs_indices, - rhs_to_lhs_indices); + hlo_sharding_util::TransposeShardingWithCollapsedDims( + lhs_sharding, lhs_to_rhs_indices, rhs_to_lhs_indices); auto rhs_sharding_transposed_to_match_lhs = - TransposeShardingWithCollapsedDims(rhs_sharding, rhs_to_lhs_indices, - lhs_to_rhs_indices); + hlo_sharding_util::TransposeShardingWithCollapsedDims( + rhs_sharding, rhs_to_lhs_indices, lhs_to_rhs_indices); auto lhs_sharding_transposed_to_match_output = - TransposeShardingWithCollapsedDims(lhs_sharding, lhs_to_output_indices, - output_to_lhs_indices); + hlo_sharding_util::TransposeShardingWithCollapsedDims( + lhs_sharding, lhs_to_output_indices, output_to_lhs_indices); auto rhs_sharding_transposed_to_match_output = - TransposeShardingWithCollapsedDims(rhs_sharding, rhs_to_output_indices, - output_to_rhs_indices); + hlo_sharding_util::TransposeShardingWithCollapsedDims( + rhs_sharding, rhs_to_output_indices, output_to_rhs_indices); auto output_sharding_transposed_to_match_lhs = - TransposeShardingWithCollapsedDims(output_sharding, output_to_lhs_indices, - lhs_to_output_indices); + hlo_sharding_util::TransposeShardingWithCollapsedDims( + output_sharding, output_to_lhs_indices, lhs_to_output_indices); auto output_sharding_transposed_to_match_rhs = - TransposeShardingWithCollapsedDims(output_sharding, output_to_rhs_indices, - rhs_to_output_indices); + hlo_sharding_util::TransposeShardingWithCollapsedDims( + output_sharding, output_to_rhs_indices, rhs_to_output_indices); // LHS and RHS are partitioned the same way and only partitioned in batch // dimensions. @@ -401,7 +412,7 @@ StatusOr PartitionBaseCase( if (output_lhs_non_contracting_partitions == num_partitions && output_sharding_transposed_to_match_lhs == lhs_sharding && ShapeSizeInBytes(rhs.base_shape()) >= - threshold_for_windowed_einsum_mib * 1024 * 1024) { + options.threshold_for_windowed_einsum_mib * 1024 * 1024) { if (rhs_contracting_partitions == num_partitions) { return emit_windowed_dot_general(0, 1, true, false); } @@ -415,7 +426,7 @@ StatusOr PartitionBaseCase( if (output_rhs_non_contracting_partitions == num_partitions && output_sharding_transposed_to_match_rhs == rhs_sharding && ShapeSizeInBytes(lhs.base_shape()) >= - threshold_for_windowed_einsum_mib * 1024 * 1024) { + options.threshold_for_windowed_einsum_mib * 1024 * 1024) { if (lhs_contracting_partitions == num_partitions) { return emit_windowed_dot_general(1, 0, true, false); } @@ -485,29 +496,36 @@ StatusOr PartitionBaseCase( return dot; } - // Output is batch partitioned. - if (output_batch_partitions == num_partitions) { - auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); - auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); - TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), - resharded_rhs.hlo(), b)); - return dot; - } - // Output is partitioned along LHS non-contracting dimensions. 
- if (output_lhs_non_contracting_partitions == num_partitions) { - auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); - auto replicated_rhs = rhs.Reshard(HloSharding::Replicate()); - TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), - replicated_rhs.hlo(), b)); - return dot; - } - // Output is partitioned along RHS non-contracting dimensions. - if (output_rhs_non_contracting_partitions == num_partitions) { - auto replicated_lhs = lhs.Reshard(HloSharding::Replicate()); - auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); - TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(), - resharded_rhs.hlo(), b)); - return dot; + if (may_reshard_without_detecting_match) { + // Output is batch partitioned. + if (output_batch_partitions == num_partitions) { + auto resharded_lhs = + lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto resharded_rhs = + rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), + resharded_rhs.hlo(), b)); + return dot; + } + // Output is partitioned along LHS non-contracting dimensions. + if (output_lhs_non_contracting_partitions == num_partitions) { + auto resharded_lhs = + lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto replicated_rhs = rhs.Reshard(HloSharding::Replicate()); + TF_ASSIGN_OR_RETURN( + auto dot, + create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), b)); + return dot; + } + // Output is partitioned along RHS non-contracting dimensions. + if (output_rhs_non_contracting_partitions == num_partitions) { + auto replicated_lhs = lhs.Reshard(HloSharding::Replicate()); + auto resharded_rhs = + rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(), + resharded_rhs.hlo(), b)); + return dot; + } } // Returns true if it is beneficial to reshard the operand at `operand_idx` @@ -558,27 +576,35 @@ StatusOr PartitionBaseCase( StatusOr PartitionDot( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, - const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, + int64 num_partitions, const std::function( HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, HloModule* module, HloInstruction* original_hlo, - int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* windowed_dot_general_loops); StatusOr PartitionDotGroupOnBatch( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, - const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, - int64 lhs_contracting_partitions, int64 rhs_contracting_partitions, - int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions, + const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, + int64 num_partitions, int64 lhs_contracting_partitions, + int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions, + int64 rhs_non_contracting_partitions, const std::function( HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, HloModule* module, HloInstruction* original_hlo, - int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + bool require_matching_devices_to_group, + const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* 
windowed_dot_general_loops) { + std::vector> + top_level_sharding_to_reset; + auto cleaner = tensorflow::gtl::MakeCleanup([&] { + for (auto& to_reset : top_level_sharding_to_reset) { + to_reset.first->set_sharding(to_reset.second); + } + }); std::vector lhs_dims; std::vector rhs_dims; std::vector output_dims; @@ -608,16 +634,20 @@ StatusOr PartitionDotGroupOnBatch( output_sharding_dims_adjusted_to_lhs[dim.output] = lhs.sharding().tile_assignment().dim(dim.lhs); } + if (require_matching_devices_to_group && lhs_rhs_dims_matching) { + lhs_rhs_dims_matching = + rhs.sharding() == UngroupSharding(AlignGroupsWith( + GroupShardingOnDims(rhs.sharding(), rhs_dims), + GroupShardingOnDims(lhs.sharding(), lhs_dims))); + } auto output_grouped = GroupShardingOnDims(output_sharding, output_dims); PartitionedHlo per_group_lhs = lhs; PartitionedHlo per_group_rhs = rhs; - auto lhs_sharding = lhs.sharding(); - auto rhs_sharding = rhs.sharding(); if (lhs_rhs_dims_matching) { auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims); auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims); - if (ShapeUtil::ByteSizeOf(lhs.base_shape()) > - ShapeUtil::ByteSizeOf(rhs.base_shape())) { + if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) > + ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) { rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped); rhs = rhs.Reshard(UngroupSharding(rhs_grouped)); } else { @@ -627,12 +657,17 @@ StatusOr PartitionDotGroupOnBatch( auto reshaped_output_tiling = output_sharding.tile_assignment(); reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs); output_grouped = AlignGroupsWith( - GroupShardingOnDims(HloSharding::Tile(reshaped_output_tiling), - output_dims), + GroupShardingOnDims( + output_sharding.ReplicateOnLastTileDim() + ? 
HloSharding::PartialTile(reshaped_output_tiling) + : HloSharding::Tile(reshaped_output_tiling), + output_dims), lhs_grouped); auto per_group_partitioner_state = CreatePerGroupPartitioningState( lhs.state(), lhs_grouped.device_groups, b); + top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs.sharding()); lhs.hlo()->set_sharding(lhs_grouped.sharding); + top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs.sharding()); rhs.hlo()->set_sharding(rhs_grouped.sharding); CHECK(lhs.hlo() != rhs.hlo() || lhs_grouped.sharding == rhs_grouped.sharding); @@ -654,9 +689,9 @@ StatusOr PartitionDotGroupOnBatch( int64 other_contracting_dim_partitions, std::vector* sharding_dims_adjusted_to_output) -> absl::optional { - if (operand.sharding().IsReplicated()) { + if (operand.sharding().IsTileMaximal()) { auto partially_sharded = PerGroupSliceFromReplicated( - operand.hlo(), operand.state().partition_id, + operand.Replicate().hlo(), operand.state().partition_id, output_grouped.device_groups, batch_dims, output_grouped.group_dim_sizes, b); partially_sharded->set_sharding(HloSharding::Replicate()); @@ -678,9 +713,16 @@ StatusOr PartitionDotGroupOnBatch( } int64 ratio = Product(*sharding_dims_adjusted_to_output) / reshaped_tiling.num_elements(); - if (ratio == non_contracting_dim_partitions && - (ratio != contracting_dim_partitions || - contracting_dim_partitions == other_contracting_dim_partitions)) { + if (operand.sharding().ReplicateOnLastTileDim() && + reshaped_tiling.dimensions().back() % ratio == 0) { + sharding_dims_adjusted_to_output->back() /= ratio; + if (sharding_dims_adjusted_to_output->back() == 1) { + sharding_dims_adjusted_to_output->pop_back(); + } + } else if (ratio == non_contracting_dim_partitions && + (ratio != contracting_dim_partitions || + contracting_dim_partitions == + other_contracting_dim_partitions)) { for (int64 dim : non_contracting_dims) { (*sharding_dims_adjusted_to_output)[dim] = 1; } @@ -688,6 +730,8 @@ StatusOr PartitionDotGroupOnBatch( for (int64 dim : contracting_dims) { (*sharding_dims_adjusted_to_output)[dim] = 1; } + } else { + return absl::nullopt; } } // If the operand is initially sharded more ways than the output in the @@ -699,9 +743,19 @@ StatusOr PartitionDotGroupOnBatch( } reshaped_tiling.Reshape(*sharding_dims_adjusted_to_output); auto grouped = AlignGroupsWith( - GroupShardingOnDims(HloSharding::Tile(reshaped_tiling), batch_dims), + GroupShardingOnDims(operand.base_shape().rank() < + sharding_dims_adjusted_to_output->size() + ? HloSharding::PartialTile(reshaped_tiling) + : HloSharding::Tile(reshaped_tiling), + batch_dims), output_grouped); + if (require_matching_devices_to_group && + operand.sharding() != UngroupSharding(grouped)) { + return absl::nullopt; + } auto resharded = operand.Reshard(UngroupSharding(grouped)); + top_level_sharding_to_reset.emplace_back(resharded.hlo(), + resharded.sharding()); resharded.hlo()->set_sharding(grouped.sharding); return PartitionedHlo(resharded.hlo(), GetPerGroupBaseShape(grouped, operand.base_shape()), @@ -754,12 +808,8 @@ StatusOr PartitionDotGroupOnBatch( GetPerGroupBaseShape(output_grouped, output_base_shape), output_grouped.sharding, dims_mapping, num_partitions / output_grouped.device_groups.size(), - create_sharded_dot, module, original_hlo, - threshold_for_windowed_einsum_mib, b, + create_sharded_dot, module, original_hlo, options, b, windowed_dot_general_loops)); - // Make sure the operands' sharding are set to the ungrouped ones. 
- lhs.hlo()->set_sharding(lhs_sharding); - rhs.hlo()->set_sharding(rhs_sharding); dot->set_sharding(UngroupSharding(output_grouped)); return PartitionedHlo(dot, output_base_shape, lhs.state()) .Reshard(output_sharding) @@ -769,65 +819,96 @@ StatusOr PartitionDotGroupOnBatch( StatusOr PartitionDotGroupOnNonContracting( bool lhs_matching, PartitionedHlo matching, PartitionedHlo other, int64 matching_contracting_partitions, int64 other_contracting_partitions, - int64 matching_non_contracting_partitions, + absl::Span + partitioned_non_contractin_dims, int64 other_non_contracting_partitions, int64 output_other_non_contracting_partitions, const Shape& output_base_shape, const HloSharding& output_sharding, - const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const DotConvDimsMapping& dims_mapping, int64 num_partitions, const std::function( HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, HloModule* module, HloInstruction* original_hlo, - int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + bool require_matching_devices_to_group, + const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* windowed_dot_general_loops) { - const bool may_replicate_other_contracting_dims = - (other_contracting_partitions == matching_non_contracting_partitions && - other_non_contracting_partitions == - output_other_non_contracting_partitions); - const bool may_replicate_other_non_contracting_dims = - matching_non_contracting_partitions == other_non_contracting_partitions && - matching_contracting_partitions == other_contracting_partitions; - std::vector other_group_dims; - if (may_replicate_other_contracting_dims && - (!may_replicate_other_non_contracting_dims || - ShapeUtil::ByteSizeOf(other.base_shape()) <= - ShapeUtil::ByteSizeOf(output_base_shape))) { - for (const auto& dim : dims_mapping.contracting_dims) { - other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs); + std::vector> + top_level_sharding_to_reset; + auto cleaner = tensorflow::gtl::MakeCleanup([&] { + for (auto& to_reset : top_level_sharding_to_reset) { + to_reset.first->set_sharding(to_reset.second); } - } else if (may_replicate_other_non_contracting_dims) { - for (const auto& dim : lhs_matching - ? dims_mapping.rhs_non_contracting_dims - : dims_mapping.lhs_non_contracting_dims) { - other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs); - } - } else if (!other.sharding().IsReplicated()) { - return nullptr; - } + }); + auto matching_sharding_dims = matching.sharding().tile_assignment().dimensions(); std::vector matching_dims; std::vector output_dims; + int64 group_count = 1; // Make sure the partitioning on matching's non-contracting dimensions // defines the same device groups for both matching and output. - for (const auto& dim : lhs_matching ? dims_mapping.lhs_non_contracting_dims - : dims_mapping.rhs_non_contracting_dims) { + for (const auto& dim : partitioned_non_contractin_dims) { int64 md = lhs_matching ? 
dim.lhs : dim.rhs; matching_sharding_dims[md] = output_sharding.tile_assignment().dim(dim.output); matching_dims.push_back(md); output_dims.push_back(dim.output); + group_count *= output_sharding.tile_assignment().dim(dim.output); } auto output_grouped = GroupShardingOnDims(output_sharding, output_dims); auto reshaped_matching_tiling = matching.sharding().tile_assignment(); reshaped_matching_tiling.Reshape(matching_sharding_dims); auto matching_grouped = AlignGroupsWith( - GroupShardingOnDims(HloSharding::Tile(reshaped_matching_tiling), - matching_dims), + GroupShardingOnDims( + matching.sharding().ReplicateOnLastTileDim() + ? HloSharding::PartialTile(reshaped_matching_tiling) + : HloSharding::Tile(reshaped_matching_tiling), + matching_dims), output_grouped); + if (require_matching_devices_to_group && + matching.sharding() != UngroupSharding(matching_grouped)) { + return nullptr; + } + + std::vector other_group_dims; + if (other.sharding().ReplicateOnLastTileDim() && + other.sharding().tile_assignment().dimensions().back() % group_count == + 0) { + other_group_dims.push_back(other.base_shape().rank()); + } else { + const bool may_replicate_other_contracting_dims = + (other_contracting_partitions == group_count && + other_non_contracting_partitions == + output_other_non_contracting_partitions); + const bool may_replicate_other_non_contracting_dims = + group_count == other_non_contracting_partitions && + matching_contracting_partitions == other_contracting_partitions; + if (auto found_dims = FindMatchingPartitionedDimsForGrouping( + other.sharding(), output_grouped.device_groups)) { + other_group_dims = std::move(*found_dims); + } else if (may_replicate_other_contracting_dims && + (!may_replicate_other_non_contracting_dims || + ShapeUtil::ByteSizeOf(other.hlo()->shape()) <= + ShapeUtil::ByteSizeOf(MakePartitionedShape( + output_base_shape, output_sharding)))) { + for (const auto& dim : dims_mapping.contracting_dims) { + other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs); + } + } else if (may_replicate_other_non_contracting_dims) { + for (const auto& dim : lhs_matching + ? dims_mapping.rhs_non_contracting_dims + : dims_mapping.lhs_non_contracting_dims) { + other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs); + } + } else { + other = other.Replicate(); + } + } + matching = matching.Reshard(UngroupSharding(matching_grouped)); auto per_group_partitioner_state = CreatePerGroupPartitioningState( matching.state(), matching_grouped.device_groups, b); + top_level_sharding_to_reset.emplace_back(matching.hlo(), matching.sharding()); matching.hlo()->set_sharding(matching_grouped.sharding); auto matching_p = PartitionedHlo( matching.hlo(), @@ -835,13 +916,31 @@ StatusOr PartitionDotGroupOnNonContracting( per_group_partitioner_state); auto partially_replicated_other = other.hlo(); - if (!other.sharding().IsReplicated()) { + if (other_group_dims.size() == 1 && + other_group_dims[0] == other.base_shape().rank()) { + // Group on replication dim. 
+ auto grouped = AlignGroupsWith( + GroupShardingOnDims( + other.sharding(), {other_group_dims[0]}, + {other.sharding().tile_assignment().dimensions().back() / + group_count}), + output_grouped); + other = other.Reshard(UngroupSharding(grouped)); + partially_replicated_other = other.hlo(); + top_level_sharding_to_reset.emplace_back(other.hlo(), other.sharding()); + partially_replicated_other->set_sharding(grouped.sharding); + } else if (!other.sharding().IsReplicated()) { auto other_grouped = AlignGroupsWith(GroupShardingOnDims(other.sharding(), other_group_dims), output_grouped, /*ignore_group_order=*/true); other = other.Reshard(UngroupSharding(other_grouped)); partially_replicated_other = - other.ReplicatePartial(other_grouped.group_dims); + other + .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + other.sharding(), other_grouped.group_dims)) + .hlo(); + top_level_sharding_to_reset.emplace_back( + partially_replicated_other, partially_replicated_other->sharding()); partially_replicated_other->set_sharding(other_grouped.sharding); } auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(), @@ -853,31 +952,188 @@ StatusOr PartitionDotGroupOnNonContracting( GetPerGroupBaseShape(output_grouped, output_base_shape), output_grouped.sharding, dims_mapping, num_partitions / matching_grouped.device_groups.size(), - create_sharded_dot, module, original_hlo, - threshold_for_windowed_einsum_mib, b, + create_sharded_dot, module, original_hlo, options, b, windowed_dot_general_loops)); - // Reset matching's sharding to the ungrouped one. - matching.hlo()->set_sharding(UngroupSharding(matching_grouped)); return dot; } +StatusOr PartitionDotGroupOnContracting( + PartitionedHlo lhs, PartitionedHlo rhs, + absl::Span + partitioned_contractin_dims, + int64 output_batch_partitions, int64 output_lhs_non_contracting_partitions, + int64 output_rhs_non_contracting_partitions, const Shape& output_base_shape, + const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, + int64 num_partitions, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, + HloModule* module, HloInstruction* original_hlo, + bool require_matching_devices_to_group, + const SpmdPartitionerOptions& options, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops) { + std::vector> + top_level_sharding_to_reset; + auto cleaner = tensorflow::gtl::MakeCleanup([&] { + for (auto& to_reset : top_level_sharding_to_reset) { + to_reset.first->set_sharding(to_reset.second); + } + }); + auto lhs_sharding = lhs.sharding(); + auto rhs_sharding = rhs.sharding(); + auto lhs_tile_shape = lhs_sharding.tile_assignment().dimensions(); + auto rhs_tile_shape = rhs_sharding.tile_assignment().dimensions(); + std::vector lhs_dims; + std::vector rhs_dims; + int64 group_count = 1; + for (const auto& dim : partitioned_contractin_dims) { + lhs_dims.push_back(dim.lhs); + rhs_dims.push_back(dim.rhs); + group_count *= lhs_sharding.tile_assignment().dim(dim.lhs); + } + if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) > + ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) { + for (const auto& dim : partitioned_contractin_dims) { + rhs_tile_shape[dim.rhs] = lhs_tile_shape[dim.lhs]; + } + auto new_tile = rhs.sharding().tile_assignment(); + new_tile.Reshape(rhs_tile_shape); + rhs_sharding = rhs_sharding.ReplicateOnLastTileDim() + ? 
HloSharding::PartialTile(new_tile) + : HloSharding::Tile(new_tile); + } else { + for (const auto& dim : partitioned_contractin_dims) { + lhs_tile_shape[dim.lhs] = rhs_tile_shape[dim.rhs]; + } + auto new_tile = lhs.sharding().tile_assignment(); + new_tile.Reshape(lhs_tile_shape); + lhs_sharding = lhs_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_tile) + : HloSharding::Tile(new_tile); + } + auto lhs_grouped = GroupShardingOnDims(lhs_sharding, lhs_dims); + auto rhs_grouped = GroupShardingOnDims(rhs_sharding, rhs_dims); + if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) > + ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) { + rhs_grouped = AlignGroupsWith(rhs_grouped, lhs_grouped); + rhs_sharding = UngroupSharding(rhs_grouped); + if (require_matching_devices_to_group && rhs.sharding() != rhs_sharding) { + return nullptr; + } + rhs = rhs.Reshard(rhs_sharding); + } else { + lhs_grouped = AlignGroupsWith(lhs_grouped, rhs_grouped); + lhs_sharding = UngroupSharding(lhs_grouped); + if (require_matching_devices_to_group && lhs.sharding() != lhs_sharding) { + return nullptr; + } + lhs = lhs.Reshard(lhs_sharding); + } + top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs_sharding); + lhs.hlo()->set_sharding(lhs_grouped.sharding); + top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs_sharding); + rhs.hlo()->set_sharding(rhs_grouped.sharding); + + HloSharding inner_output_sharding = HloSharding::Replicate(); + HloSharding outer_output_tmp_sharding = HloSharding::Replicate(); + if (output_sharding.ReplicateOnLastTileDim() && + output_sharding.tile_assignment().dimensions().back() % group_count == + 0) { + auto grouped = AlignGroupsWith( + GroupShardingOnDims( + output_sharding, + {output_sharding.tile_assignment().num_dimensions() - 1}, + {output_sharding.tile_assignment().dimensions().back() / + group_count}), + lhs_grouped); + outer_output_tmp_sharding = UngroupSharding(grouped); + inner_output_sharding = std::move(grouped.sharding); + } else { + std::vector group_dims; + if (auto found_dims = FindMatchingPartitionedDimsForGrouping( + output_sharding, lhs_grouped.device_groups)) { + group_dims = std::move(*found_dims); + } else if (output_lhs_non_contracting_partitions == group_count || + output_rhs_non_contracting_partitions == group_count || + output_batch_partitions == group_count) { + if (output_lhs_non_contracting_partitions == group_count) { + for (const auto& dim : dims_mapping.lhs_non_contracting_dims) { + group_dims.push_back(dim.output); + } + } else if (output_rhs_non_contracting_partitions == group_count) { + for (const auto& dim : dims_mapping.rhs_non_contracting_dims) { + group_dims.push_back(dim.output); + } + } else { + for (const auto& dim : dims_mapping.batch_dims) { + group_dims.push_back(dim.output); + } + } + } + if (!group_dims.empty()) { + auto grouped = AlignGroupsWith( + GroupShardingOnDims(output_sharding, group_dims), lhs_grouped); + inner_output_sharding = grouped.sharding; + outer_output_tmp_sharding = + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + UngroupSharding(grouped), group_dims); + } + } + auto inner_state = CreatePerGroupPartitioningState( + lhs.state(), lhs_grouped.device_groups, b); + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDot( + PartitionedHlo(lhs.hlo(), + GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()), + inner_state), + PartitionedHlo(rhs.hlo(), + GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()), + inner_state), + MakePartitionedShape(output_base_shape, outer_output_tmp_sharding), + inner_output_sharding, 
dims_mapping, num_partitions / group_count, + create_sharded_dot, module, original_hlo, options, b, + windowed_dot_general_loops)); + if (!dot) { + return nullptr; + } + std::vector other_lhs_dims; + for (int64 i = 0; i < lhs_sharding.tile_assignment().num_dimensions(); ++i) { + if (!absl::c_linear_search(lhs_dims, i)) { + other_lhs_dims.push_back(i); + } + } + auto inverse_grouped = GroupShardingOnDims(lhs_sharding, other_lhs_dims); + auto ar = + CreatePerGroupPartitioningState(lhs.state(), + inverse_grouped.device_groups, b) + .collective_ops_creator.create_cross_partition_all_reduce( + b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), + {}, (*lhs.state().next_channel_id)++); + ar->set_sharding(outer_output_tmp_sharding); + return PartitionedHlo(ar, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); +} + // Recursive partitioning function. If there are partial dimensions matching in // the operands and output, group the devices and recursively partition the // in-group dot. StatusOr PartitionDot( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, - const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, + int64 num_partitions, const std::function( HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, HloModule* module, HloInstruction* original_hlo, - int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + bool require_matching_devices_to_group, + const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* windowed_dot_general_loops) { // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output. auto get_partitions_for_dims = [&](const HloSharding& sharding, - absl::Span dims, + absl::Span dims, int lhs_rhs_or_output) { int64 partitions = 1; if (sharding.IsTileMaximal()) { @@ -913,6 +1169,52 @@ StatusOr PartitionDot( output_sharding, dims_mapping.lhs_non_contracting_dims, 2); const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims( output_sharding, dims_mapping.rhs_non_contracting_dims, 2); + const int64 lhs_conv_spatial_partitions = get_partitions_for_dims( + lhs.sharding(), dims_mapping.conv_spatial_dims, 0); + const int64 rhs_conv_spatial_partitions = get_partitions_for_dims( + rhs.sharding(), dims_mapping.conv_spatial_dims, 1); + const int64 output_conv_spatial_partitions = get_partitions_for_dims( + output_sharding, dims_mapping.conv_spatial_dims, 2); + // Before we find partial matches along the dimensions, invoke base case again + // without may_reshard_without_detecting_match. + + // Try partition the purely spatially-partitioned convolution with convolution + // spatial dimension partitioned or depthwise parallel dimension partitioned. + if (!dims_mapping.conv_spatial_dims.empty() && + (lhs_conv_spatial_partitions > 1 || rhs_conv_spatial_partitions > 1 || + output_conv_spatial_partitions > 1 || + original_hlo->batch_group_count() > 1 || + original_hlo->feature_group_count() > 1)) { + const auto& conv_dnums = original_hlo->convolution_dimension_numbers(); + auto window = original_hlo->window(); + + // TODO(wangtao): remove this hack by passing create_sharded_conv to + // PartitionConv. + // Update convolution window when it is in the recursive call for + // batch_dims. 
+ if (original_hlo->batch_group_count() == 1 && + original_hlo->feature_group_count() == 1 && + !ShapeUtil::Compatible(original_hlo->shape(), output_base_shape)) { + for (const auto& dim : dims_mapping.batch_dims) { + auto wd = window.mutable_dimensions(dim.spatial); + wd->set_size(lhs.hlo()->shape().dimensions( + conv_dnums.input_spatial_dimensions(dim.spatial))); + wd->set_stride(std::max(1, wd->size() - 1)); + wd->set_base_dilation(wd->size()); + } + } + + TF_ASSIGN_OR_RETURN( + auto partitioned_conv, + PartitionConvolution(lhs, rhs, output_base_shape, output_sharding, + dims_mapping, window, original_hlo, num_partitions, + options, lhs.state().partition_id, module, b)); + + if (partitioned_conv) { + return partitioned_conv; + } + } + TF_ASSIGN_OR_RETURN( auto try_partitioned_dot, PartitionBaseCase( @@ -922,8 +1224,9 @@ StatusOr PartitionDot( lhs_contracting_partitions, rhs_contracting_partitions, lhs_non_contracting_partitions, rhs_non_contracting_partitions, output_lhs_non_contracting_partitions, - output_rhs_non_contracting_partitions, - threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops)); + output_rhs_non_contracting_partitions, options, b, + windowed_dot_general_loops, + /*may_reshard_without_detecting_match=*/false)); if (try_partitioned_dot) { return try_partitioned_dot; } @@ -941,7 +1244,7 @@ StatusOr PartitionDot( num_partitions, lhs_contracting_partitions, rhs_contracting_partitions, lhs_non_contracting_partitions, rhs_non_contracting_partitions, create_sharded_dot, module, - original_hlo, threshold_for_windowed_einsum_mib, b, + original_hlo, require_matching_devices_to_group, options, b, windowed_dot_general_loops)); if (dot) { return dot; @@ -974,19 +1277,180 @@ StatusOr PartitionDot( : rhs_contracting_partitions, lhs_matching ? rhs_contracting_partitions : lhs_contracting_partitions, - lhs_matching ? lhs_non_contracting_partitions - : rhs_non_contracting_partitions, + lhs_matching ? dims_mapping.lhs_non_contracting_dims + : dims_mapping.rhs_non_contracting_dims, lhs_matching ? rhs_non_contracting_partitions : lhs_non_contracting_partitions, lhs_matching ? output_rhs_non_contracting_partitions : output_lhs_non_contracting_partitions, output_base_shape, output_sharding, dims_mapping, num_partitions, create_sharded_dot, module, original_hlo, - threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops)); + require_matching_devices_to_group, options, b, + windowed_dot_general_loops)); if (dot) { return dot; } } + if (lhs_non_contracting_partitions > 1 && + output_lhs_non_contracting_partitions > 1) { + // If part of LHS non-contracting dims match output, try them. 
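A small sketch, assuming plain tile-count vectors rather than `HloSharding`, of the partial-match scan used by the grouping cases here: only dimensions whose partition counts agree between an operand and the output (and exceed 1) are worth grouping on. `DimPair` and `MatchingPartitionedDims` are illustrative names.

#include <cstdint>
#include <vector>

// Hypothetical mirror of one dims_mapping entry: an operand dimension and the
// output dimension it corresponds to.
struct DimPair { int64_t operand_dim; int64_t output_dim; };

// Returns the pairs whose partition counts agree (and are > 1) between the
// operand tiling and the output tiling, i.e. the dimensions that can be
// grouped without resharding either side.
std::vector<DimPair> MatchingPartitionedDims(
    const std::vector<int64_t>& operand_tile_dims,
    const std::vector<int64_t>& output_tile_dims,
    const std::vector<DimPair>& mapping) {
  std::vector<DimPair> matching;
  for (const DimPair& p : mapping) {
    int64_t partitions = operand_tile_dims[p.operand_dim];
    if (partitions > 1 && partitions == output_tile_dims[p.output_dim]) {
      matching.push_back(p);
    }
  }
  return matching;
}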
+ std::vector matching_dims; + for (const auto& dim : dims_mapping.lhs_non_contracting_dims) { + int64 lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs); + if (lhs_partitions > 1 && + lhs_partitions == output_sharding.tile_assignment().dim(dim.output)) { + matching_dims.push_back(dim); + } + } + if (!matching_dims.empty()) { + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDotGroupOnNonContracting( + /*lhs_matching=*/true, lhs, rhs, lhs_contracting_partitions, + rhs_contracting_partitions, matching_dims, + rhs_non_contracting_partitions, + output_rhs_non_contracting_partitions, output_base_shape, + output_sharding, dims_mapping, num_partitions, create_sharded_dot, + module, original_hlo, require_matching_devices_to_group, options, + b, windowed_dot_general_loops)); + if (dot) { + return dot; + } + } + } + if (rhs_non_contracting_partitions > 1 && + output_rhs_non_contracting_partitions > 1) { + // If part of RHS non-contracting dims match output, try them. + std::vector matching_dims; + for (const auto& dim : dims_mapping.rhs_non_contracting_dims) { + int64 rhs_partitions = rhs.sharding().tile_assignment().dim(dim.rhs); + if (rhs_partitions > 1 && + rhs_partitions == output_sharding.tile_assignment().dim(dim.output)) { + matching_dims.push_back(dim); + } + } + if (!matching_dims.empty()) { + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDotGroupOnNonContracting( + /*lhs_matching=*/false, rhs, lhs, rhs_contracting_partitions, + lhs_contracting_partitions, matching_dims, + lhs_non_contracting_partitions, + output_lhs_non_contracting_partitions, output_base_shape, + output_sharding, dims_mapping, num_partitions, create_sharded_dot, + module, original_hlo, require_matching_devices_to_group, options, + b, windowed_dot_general_loops)); + if (dot) { + return dot; + } + } + } + + // Case 3: Group partitions by contracting dimensions. + if (lhs_contracting_partitions == rhs_contracting_partitions && + lhs_contracting_partitions > 1) { + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDotGroupOnContracting( + lhs, rhs, dims_mapping.contracting_dims, output_batch_partitions, + output_lhs_non_contracting_partitions, + output_rhs_non_contracting_partitions, output_base_shape, + output_sharding, dims_mapping, num_partitions, create_sharded_dot, + module, original_hlo, require_matching_devices_to_group, options, b, + windowed_dot_general_loops)); + if (dot) { + return dot; + } + } + if (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1) { + // If part of contracting dims match, try them. + std::vector matching_dims; + for (const auto& dim : dims_mapping.contracting_dims) { + int64 lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs); + if (lhs_partitions > 1 && + lhs_partitions == rhs.sharding().tile_assignment().dim(dim.rhs)) { + matching_dims.push_back(dim); + } + } + if (!matching_dims.empty()) { + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDotGroupOnContracting( + lhs, rhs, matching_dims, output_batch_partitions, + output_lhs_non_contracting_partitions, + output_rhs_non_contracting_partitions, output_base_shape, + output_sharding, dims_mapping, num_partitions, create_sharded_dot, + module, original_hlo, require_matching_devices_to_group, options, + b, windowed_dot_general_loops)); + if (dot) { + return dot; + } + } + } + + // Case 4: If operands are replicated but output is partially replicated, + // recursive call with partial replication removed. 
+ if (lhs.sharding().IsReplicated() && rhs.sharding().IsReplicated() && + output_sharding.ReplicateOnLastTileDim()) { + auto grouped_output = + GroupShardingOnDims(output_sharding, {output_base_shape.rank()}); + auto inner_state = CreatePerGroupPartitioningState( + lhs.state(), grouped_output.device_groups, b); + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDot(PartitionedHlo(lhs.hlo(), lhs.base_shape(), inner_state), + PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state), + output_base_shape, grouped_output.sharding, dims_mapping, + output_sharding.NumTiles(), create_sharded_dot, module, + original_hlo, options, b, windowed_dot_general_loops)); + if (dot) { + return dot; + } + } + + // We failed to find partial matches, invoke base case again with + // may_reshard_without_detecting_match. + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionBaseCase( + lhs, rhs, output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, module, original_hlo, + lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions, + lhs_contracting_partitions, rhs_contracting_partitions, + lhs_non_contracting_partitions, rhs_non_contracting_partitions, + output_lhs_non_contracting_partitions, + output_rhs_non_contracting_partitions, options, b, + windowed_dot_general_loops, + /*may_reshard_without_detecting_match=*/true)); + if (dot) { + return dot; + } + return nullptr; +} + +StatusOr PartitionDot( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, + int64 num_partitions, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, + HloModule* module, HloInstruction* original_hlo, + const SpmdPartitionerOptions& options, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops) { + // First try partitioning without resharding the groups, then try allow + // resharding the groups. + for (bool require_matching_devices_to_group : {true, false}) { + TF_ASSIGN_OR_RETURN( + auto try_partition, + PartitionDot(lhs, rhs, output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, module, original_hlo, + require_matching_devices_to_group, options, b, + windowed_dot_general_loops)); + if (try_partition) { + return try_partition; + } + } // Default action. 
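The wrapper above runs the whole decision tree twice, first requiring existing device groups to match and then allowing them to be resharded. A compact, self-contained sketch of that strict-then-relaxed driver, with a stubbed `TryPartition` standing in for the real recursive call (both names are illustrative, not the actual API):

#include <optional>
#include <string>

// Hypothetical attempt at a partitioning strategy; in this stub it succeeds
// only when device groups are allowed to be resharded (strict == false).
std::optional<std::string> TryPartition(bool require_matching_device_groups) {
  if (require_matching_device_groups) return std::nullopt;
  return std::string("grouped-dot");
}

// Mirrors the two-pass driver: first try without resharding the device
// groups, then allow resharding, and fall back to replication if both fail.
std::string PartitionOrReplicate() {
  for (bool strict : {true, false}) {
    if (auto result = TryPartition(strict)) return *result;
  }
  return "replicate-and-dot";  // Default action.
}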
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.Replicate().hlo(), @@ -1000,7 +1464,7 @@ StatusOr PartitionDot( } // namespace Status SpmdPartitioningVisitor::HandleDotHelper( - HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, + HloInstruction* hlo, const DotConvDimsMapping& dims_mapping, const std::function( HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) { auto& lhs = GetPartitionedHlo(hlo->operand(0)); @@ -1008,9 +1472,8 @@ Status SpmdPartitioningVisitor::HandleDotHelper( TF_ASSIGN_OR_RETURN( auto partitioned_dot, PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping, - num_partitions_, create_sharded_dot, module_, hlo, - options_.threshold_for_windowed_einsum_mib, &b_, - &windowed_dot_general_loops_)); + num_partitions_, create_sharded_dot, module_, hlo, options_, + &b_, &windowed_dot_general_loops_)); SetPartitionedHlo(hlo, [&] { return partitioned_dot; }); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc new file mode 100644 index 00000000000..bdc96afba88 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc @@ -0,0 +1,132 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.h" + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { +namespace { + +HloCollectiveInstruction* MayConsiderAsAllGather(HloInstruction* hlo, + bool for_replicas) { + auto coll = DynCast(hlo); + if (!coll) { + return nullptr; + } + if (coll->constrain_layout()) { + return nullptr; + } + if (for_replicas == coll->channel_id().has_value()) { + return nullptr; + } + if (coll->opcode() == HloOpcode::kAllGather) { + return coll; + } + // Consider broadcast -> dynamic-update-slice -> all-reduce as all-gather. + if (coll->opcode() == HloOpcode::kAllReduce && coll->shape().IsArray()) { + auto operand = coll->operand(0); + return operand->opcode() == HloOpcode::kDynamicUpdateSlice && + operand->operand(0)->opcode() == HloOpcode::kBroadcast + ? coll + : nullptr; + } + return nullptr; +} + +StatusOr RunOnComputation(HloComputation* comp, bool for_replicas, + int64 distance_threshold) { + // We consider estimate the live ranges of all-gathers by comparing their + // users' distance to the root, e.g., height. 
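A standalone sketch of the height estimate mentioned in the comment above, assuming a post-ordered node list in which each node stores the indices of its users, so that walking the list backwards visits users before operands. `Node` and `ComputeHeights` are illustrative names; the real pass works on `HloInstruction*`.

#include <algorithm>
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical node in a post-ordered computation: users always appear later
// in the post-order than their operands.
struct Node { std::vector<int> users; };

// Height is 0 for nodes with no users (the root/outputs) and otherwise one
// more than the highest user. Nodes with similar heights have overlapping
// live ranges, which is what the CSE distance threshold compares.
std::map<int, int64_t> ComputeHeights(const std::vector<Node>& post_order) {
  std::map<int, int64_t> height;
  for (int i = static_cast<int>(post_order.size()) - 1; i >= 0; --i) {
    int64_t h = 0;
    for (int user : post_order[i].users) {
      h = std::max(h, height[user] + 1);
    }
    height[i] = h;
  }
  return height;
}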
+ absl::flat_hash_map height; + auto ordered_hlos = comp->MakeInstructionPostOrder(); + int64 max_height = 0; + for (auto it = ordered_hlos.rbegin(); it != ordered_hlos.rend(); ++it) { + auto hlo = *it; + int64 h = 0; + for (auto user : hlo->users()) { + h = std::max(h, height[user]) + 1; + } + max_height = std::max(max_height, h); + height[hlo] = h; + } + + auto lowest_user_height = [&](const HloInstruction* hlo) { + int64 lowest = height[hlo]; + for (auto user : hlo->users()) { + lowest = std::min(lowest, height[user]); + } + return lowest; + }; + + absl::flat_hash_map> + operand_to_ag; + bool changed = false; + for (auto hlo : ordered_hlos) { + auto ag = MayConsiderAsAllGather(hlo, for_replicas); + if (!ag) { + continue; + } + + auto& earlier_ags = operand_to_ag[ag->operand(0)]; + bool found = false; + int64 ag_height = height[ag]; + for (auto& eag : earlier_ags) { + auto old_channel_id = ag->channel_id(); + if (eag->channel_id() && ag->channel_id()) { + ag->set_channel_id(eag->channel_id()); + } + if (!eag->Identical(*ag)) { + ag->set_channel_id(old_channel_id); + continue; + } + found = true; + ag->set_channel_id(old_channel_id); + if (lowest_user_height(eag) > ag_height + distance_threshold) { + eag = ag; + continue; + } + changed = true; + VLOG(1) << "Replacing " << ag->ToString() << " with " << eag->ToString(); + TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(eag)); + break; + } + if (!found) { + earlier_ags.push_back(ag); + } + } + return changed; +} + +} // namespace + +StatusOr ScheduleAwareAllGatherCSE::Run(HloModule* module) { + bool changed = false; + for (auto comp : module->computations()) { + TF_ASSIGN_OR_RETURN( + auto comp_changed, + RunOnComputation(comp, for_replicas_, distance_threshold_)); + changed |= comp_changed; + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.h b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.h new file mode 100644 index 00000000000..4653286ae97 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SCHEDULE_AWARE_ALL_GATHER_CSE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SCHEDULE_AWARE_ALL_GATHER_CSE_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Performs CSE for all-gather if their users are within reasonable live range. +class ScheduleAwareAllGatherCSE : public HloModulePass { + public: + // distance_threshold: maximum live range (in number of HLO instructions on + // the path) to consider CSE. + // for_replicas: specifies if this pass is for cross-replica or + // cross-partition all-gathers. 
+ explicit ScheduleAwareAllGatherCSE(int64 distance_threshold, + bool for_replicas) + : distance_threshold_(distance_threshold), for_replicas_(for_replicas) {} + + ~ScheduleAwareAllGatherCSE() override = default; + absl::string_view name() const override { + return "schedule-aware-all-gather-cse"; + } + + StatusOr Run(HloModule* module) override; + + private: + int64 distance_threshold_; + bool for_replicas_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SCHEDULE_AWARE_ALL_GATHER_CSE_H_ diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index 2d76966a494..ceb81330639 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" @@ -216,20 +217,147 @@ HloInstruction* SpmdBuilder::AddInstruction( if (visiting_hlo_) { instructions_[visiting_hlo_].push_back(hlo); } + if (hlo->opcode() == HloOpcode::kBroadcast) { + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!absl::c_linear_search(hlo->dimensions(), i)) { + broadcast_dims_[hlo].insert(i); + } + } + } + if (hlo->IsElementwise() && hlo->operand_count() > 0) { + absl::flat_hash_set broadcast_dims; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + broadcast_dims.insert(i); + } + for (int64 i = 0; i < hlo->operand_count(); ++i) { + auto it = broadcast_dims_.find(hlo->operand(i)); + if (it == broadcast_dims_.end()) { + broadcast_dims.clear(); + break; + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!it->second.contains(i)) { + broadcast_dims.erase(i); + } + } + } + if (!broadcast_dims.empty()) { + broadcast_dims_[hlo] = std::move(broadcast_dims); + } + } + if (hlo->opcode() == HloOpcode::kTranspose) { + auto it = broadcast_dims_.find(hlo->operand(0)); + if (it != broadcast_dims_.end()) { + absl::flat_hash_set xpose_broadcast_dims; + std::vector reverse_map(hlo->shape().rank()); + for (int64 i = 0; i < reverse_map.size(); ++i) { + reverse_map[hlo->dimensions(i)] = i; + } + for (int64 dim : it->second) { + xpose_broadcast_dims.insert(reverse_map[dim]); + } + broadcast_dims_[hlo] = std::move(xpose_broadcast_dims); + } + } + if (hlo->opcode() == HloOpcode::kReshape && + Product(hlo->shape().dimensions()) > 0) { + auto it = broadcast_dims_.find(hlo->operand(0)); + if (it != broadcast_dims_.end()) { + absl::flat_hash_set reshape_broadcast_dims; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + reshape_broadcast_dims.insert(i); + } + std::vector before_dim_size_stack; + std::vector after_dim_size_stack; + for (int64 i = hlo->operand(0)->shape().rank() - 1; i >= 0; --i) { + before_dim_size_stack.push_back(hlo->operand(0)->shape().dimensions(i)); + } + for (int64 i = hlo->shape().rank() - 1; i >= 0; --i) { + after_dim_size_stack.push_back(hlo->shape().dimensions(i)); + } + while (!before_dim_size_stack.empty() && !after_dim_size_stack.empty()) { + int64 before_size = before_dim_size_stack.back(); + int64 after_size = after_dim_size_stack.back(); + int64 current_before_dim = + hlo->operand(0)->shape().rank() - before_dim_size_stack.size(); + int64 current_after_dim = + hlo->shape().rank() - after_dim_size_stack.size(); + before_dim_size_stack.pop_back(); + after_dim_size_stack.pop_back(); + if 
(!it->second.contains(current_before_dim)) { + reshape_broadcast_dims.erase(current_after_dim); + } + if (before_size == after_size) { + continue; + } + if (before_size % after_size == 0) { + // Split dim. + before_dim_size_stack.push_back(before_size / after_size); + } else if (after_size % before_size == 0) { + // Merge dim. + after_dim_size_stack.push_back(after_size / before_size); + } else { + // Other cases, mark all remaining dims as non-broadcast. + for (int64 i = current_after_dim; i < hlo->shape().rank(); ++i) { + reshape_broadcast_dims.erase(i); + } + break; + } + } + if (!before_dim_size_stack.empty() || !after_dim_size_stack.empty()) { + reshape_broadcast_dims.clear(); + } + if (!reshape_broadcast_dims.empty()) { + broadcast_dims_[hlo] = std::move(reshape_broadcast_dims); + } + } + } + if (hlo->opcode() == HloOpcode::kSlice || + hlo->opcode() == HloOpcode::kDynamicSlice) { + auto it = broadcast_dims_.find(hlo->operand(0)); + if (it != broadcast_dims_.end()) { + auto dims = it->second; + broadcast_dims_[hlo] = std::move(dims); + } + } + if (hlo->opcode() == HloOpcode::kPad) { + auto it = broadcast_dims_.find(hlo->operand(0)); + if (it != broadcast_dims_.end()) { + absl::flat_hash_set pad_broadcast_dims; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + const auto& dim = hlo->padding_config().dimensions(i); + if (dim.edge_padding_low() == 0 && dim.edge_padding_high() == 0 && + dim.interior_padding() == 0 && it->second.contains(i)) { + pad_broadcast_dims.insert(i); + } + } + if (!pad_broadcast_dims.empty()) { + broadcast_dims_[hlo] = std::move(pad_broadcast_dims); + } + } + } return hlo; } PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) { auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; - for (auto& entry : cache) { - if (entry.first == target) { - return entry.second; + const bool is_to_replicate = + hlo_->shape().IsArray() && target.NumTiles() < sharding().NumTiles(); + if (!is_to_replicate || state_.partitioner->options().cache_all_gather) { + for (auto& entry : cache) { + if (entry.first == target) { + return entry.second; + } } } - cache.emplace_back(target, ReshardNoCache(target)); - state_.reshard_cache->per_hlo_cache[cache.back().second.hlo()] + auto resharded = ReshardNoCache(target); + state_.reshard_cache->per_hlo_cache[resharded.hlo()] .reshard_cache.emplace_back(sharding(), *this); - return cache.back().second; + if (!is_to_replicate || state_.partitioner->options().cache_all_gather) { + cache.emplace_back(target, std::move(resharded)); + return cache.back().second; + } + return resharded; } PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { @@ -282,6 +410,20 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { return ReshardWithAllToAll(target, *src_tgt_dims); } + if (!target.IsTileMaximal() && sharding().ReplicateOnLastTileDim()) { + auto try_reshard = ReshardFromPartialReplicateWithDynamicSlice(target); + if (try_reshard.has_value()) { + return try_reshard.value(); + } + } + + if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) { + auto try_reshard = ReshardToPartialReplicateWithAllGather(target); + if (try_reshard.has_value()) { + return try_reshard.value(); + } + } + // If not replicated yet, first replicate and then reshard to use one of the // two implementations below. 
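The reshape case of the broadcast-dimension tracking above walks two stacks of dimension sizes to match split and merged dimensions. A standalone transcription of that walk over plain vectors (sizes assumed positive, as in the guarded `Product(...) > 0` case); `ReshapeBroadcastDims` is an illustrative name.

#include <cstdint>
#include <set>
#include <vector>

// Propagates "every element along this dimension is identical" facts through
// a reshape, working on plain dimension vectors instead of Shape.
std::set<int64_t> ReshapeBroadcastDims(const std::vector<int64_t>& before,
                                       const std::vector<int64_t>& after,
                                       const std::set<int64_t>& before_bcast) {
  std::set<int64_t> result;
  for (int64_t i = 0; i < static_cast<int64_t>(after.size()); ++i) result.insert(i);
  std::vector<int64_t> before_stack(before.rbegin(), before.rend());
  std::vector<int64_t> after_stack(after.rbegin(), after.rend());
  while (!before_stack.empty() && !after_stack.empty()) {
    int64_t before_size = before_stack.back();
    int64_t after_size = after_stack.back();
    int64_t cur_before = before.size() - before_stack.size();
    int64_t cur_after = after.size() - after_stack.size();
    before_stack.pop_back();
    after_stack.pop_back();
    if (!before_bcast.count(cur_before)) result.erase(cur_after);
    if (before_size == after_size) continue;
    if (before_size % after_size == 0) {
      before_stack.push_back(before_size / after_size);  // Split dimension.
    } else if (after_size % before_size == 0) {
      after_stack.push_back(after_size / before_size);   // Merged dimension.
    } else {
      // Irregular case: give up on all remaining output dimensions.
      for (int64_t i = cur_after; i < static_cast<int64_t>(after.size()); ++i)
        result.erase(i);
      break;
    }
  }
  if (!before_stack.empty() || !after_stack.empty()) result.clear();
  return result;
}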
if (!sharding().IsReplicated()) { @@ -296,6 +438,19 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { return PartitionedHlo(copy, base_shape_, state_); } + // 'Replicated' to partial replicated. + if (target.ReplicateOnLastTileDim()) { + std::vector group_dims(target.tile_assignment().num_dimensions() - + 1); + std::iota(group_dims.begin(), group_dims.end(), 0); + auto target_grouped = GroupShardingOnDims(target, group_dims); + auto partially_sharded = PerGroupSliceFromReplicated( + hlo_, state_.partition_id, target_grouped.device_groups, group_dims, + target_grouped.group_dim_sizes, state_.b); + partially_sharded->set_sharding(target); + return PartitionedHlo(partially_sharded, base_shape(), state_); + } + // 'Replicated' to 'Tiled'. auto padded_hlo = PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b); @@ -651,6 +806,14 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window, } PartitionedHlo PartitionedHlo::Replicate() { + auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; + if (state_.partitioner->options().cache_all_gather) { + for (auto& entry : cache) { + if (entry.first.IsReplicated()) { + return entry.second; + } + } + } const HloSharding& sharding = hlo_->sharding(); const Shape& shape = hlo_->shape(); CHECK(!shape.IsTuple() && shape.element_type() != TOKEN); @@ -658,7 +821,6 @@ PartitionedHlo PartitionedHlo::Replicate() { if (sharding.IsReplicated()) { return *this; } - auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; for (auto& entry : cache) { if (entry.first.IsReplicated()) { return entry.second; @@ -667,8 +829,11 @@ PartitionedHlo PartitionedHlo::Replicate() { auto update_cache = [&](PartitionedHlo resharded) { state_.reshard_cache->per_hlo_cache[resharded.hlo()] .reshard_cache.emplace_back(sharding, *this); - cache.emplace_back(HloSharding::Replicate(), std::move(resharded)); - return cache.back().second; + if (state_.partitioner->options().cache_all_gather) { + cache.emplace_back(HloSharding::Replicate(), std::move(resharded)); + return cache.back().second; + } + return resharded; }; // 'Single Device' to 'Repliated'. if (sharding.IsTileMaximal()) { @@ -724,11 +889,160 @@ HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span dims) { std::vector strides(target_shape.rank(), 1); result = state_.b->AddInstruction( HloInstruction::CreateSlice(target_shape, result, start_indices, - base_shape_.dimensions(), strides)); + target_shape.dimensions(), strides)); } return result; } +absl::optional +PartitionedHlo::ReshardToPartialReplicateWithAllGather( + const HloSharding& target) { + if (!target.ReplicateOnLastTileDim()) { + return absl::nullopt; + } + // Tiled/partial replicate to partial replicate + // Get the comptible sharding to target with resharding by all reduce. + auto compatible_sharding = + PartialReplicateReshardCompatibleSharding(target, sharding()); + if (!compatible_sharding.has_value()) { + return absl::nullopt; + } + + const auto& temp_sharding = compatible_sharding.value(); + auto partitioned_hlo = *this; + // Use collective permute to adjust device assignment if needed. + if (CanReshardWithCollectivePermute(sharding(), temp_sharding)) { + partitioned_hlo = + partitioned_hlo.ReshardWithCollectivePermute(temp_sharding); + } + + // Get replicate dims and replicate factor of each dimensions. 
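As a shape-level illustration of the step that follows, assuming plain per-dimension tile counts rather than `HloSharding`: a dimension needs all-gathering exactly when the source tiling is finer than the target tiling, and the ratio is the replication factor used to group devices. `ReplicationFactors` is an illustrative name.

#include <cstdint>
#include <utility>
#include <vector>

// Returns (dimension, factor) pairs for every dimension where the source
// tiling has to be collapsed by all-gather to reach the target tiling.
std::vector<std::pair<int64_t, int64_t>> ReplicationFactors(
    const std::vector<int64_t>& source_tiles,
    const std::vector<int64_t>& target_tiles) {
  std::vector<std::pair<int64_t, int64_t>> dims_and_factors;
  for (size_t dim = 0; dim < source_tiles.size(); ++dim) {
    int64_t factor = source_tiles[dim] / target_tiles[dim];
    if (factor > 1) {
      dims_and_factors.emplace_back(dim, factor);
    }
  }
  return dims_and_factors;
}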
+ int64 rank = hlo_->shape().rank(); + std::vector replicate_dims; + std::vector replicate_factors; + for (int64 dim = 0; dim < rank; dim++) { + int64 replicate_factor = temp_sharding.tile_assignment().dim(dim) / + target.tile_assignment().dim(dim); + if (replicate_factor > 1) { + replicate_dims.emplace_back(dim); + replicate_factors.emplace_back(replicate_factor); + } + } + + // Do left halo exchange if all-reduce directly will remove useful data + // from the source. + auto halo_exchange = TileToPartialReplicateHaloExchange( + partitioned_hlo.hlo_, base_shape_, temp_sharding, target, replicate_dims, + partitioned_hlo.state().collective_ops_creator, + partitioned_hlo.state().next_channel_id, + partitioned_hlo.state().partition_id, partitioned_hlo.state().b); + if (!halo_exchange.has_value()) { + return absl::nullopt; + } + auto halo_exchange_hlo = halo_exchange.value(); + // Grouped on replicate dimensions. + auto sharding_grouped = + GroupShardingOnDims(temp_sharding, replicate_dims, replicate_factors); + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + partitioned_hlo.state(), sharding_grouped.device_groups, + partitioned_hlo.state().b); + auto base_shape = MakePartitionedShape(base_shape_, target); + // It's possible that halo_exchange_hlo == hlo.hlo(). + // Record the sharding of hlo here, and reset it before return. + auto original_sharding = partitioned_hlo.sharding(); + halo_exchange_hlo->set_sharding(sharding_grouped.sharding); + auto partial_replicate_hlo = PartitionedHlo(halo_exchange_hlo, base_shape, + per_group_partitioner_state); + HloInstruction* result = + partial_replicate_hlo.ReplicatePartial(replicate_dims); + partitioned_hlo.hlo()->set_sharding(original_sharding); + result->set_sharding(target); + return PartitionedHlo(result, base_shape_, partitioned_hlo.state()); +} + +absl::optional +PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice( + const HloSharding& target) { + if (!sharding().ReplicateOnLastTileDim()) { + return absl::nullopt; + } + + // Get the temp sharding target from partial replicate to target tile dims. + // target_compatible_sharding has the same tile_assignment dimensions + // as the target and can reshard to target by collective permute. + // target_compatible_sharding could have different device assignment as + // targe. sharding() can reshard to target_compatible_sharding by + // dynamic slice. + auto target_compatible_sharding = + PartialReplicateReshardCompatibleSharding(sharding(), target); + // Reshard to target_compatible_sharding by dynamic slice. + if (!target_compatible_sharding.has_value()) { + return absl::nullopt; + } + std::vector expand_tile_dims; + std::vector tiling_dim_factors; + int64 rank = hlo_->shape().rank(); + tiling_dim_factors.reserve(target.tile_assignment().num_dimensions()); + const auto& temp_target_sharding = target_compatible_sharding.value(); + for (int64 dim = 0; dim < rank; dim++) { + if (temp_target_sharding.tile_assignment().dim(dim) > + sharding().tile_assignment().dim(dim)) { + expand_tile_dims.push_back(dim); + } + tiling_dim_factors.emplace_back( + temp_target_sharding.tile_assignment().dim(dim) / + sharding().tile_assignment().dim(dim)); + } + + // Add another dimension in tiling_dim_factors if target is partial replicate. + if (target.ReplicateOnLastTileDim()) { + tiling_dim_factors.emplace_back( + target.tile_assignment().dimensions().back()); + } + + // Get per_group partitioner state. 
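A conceptual, simplified version of how a linear partition id turns into dynamic-slice start offsets under a row-major (iota) tile assignment with evenly divisible shards; this is not the real `MakePartitionOffsets`, and `ShardStartOffsets` is a made-up name.

#include <cstdint>
#include <vector>

// Decomposes the partition id into per-dimension tile coordinates (last
// dimension varies fastest) and scales each coordinate by the shard size to
// get the slice start offset for that dimension.
std::vector<int64_t> ShardStartOffsets(int64_t partition_id,
                                       const std::vector<int64_t>& tiles_per_dim,
                                       const std::vector<int64_t>& shard_sizes) {
  std::vector<int64_t> offsets(tiles_per_dim.size());
  for (int64_t dim = static_cast<int64_t>(tiles_per_dim.size()) - 1; dim >= 0; --dim) {
    int64_t coordinate = partition_id % tiles_per_dim[dim];
    partition_id /= tiles_per_dim[dim];
    offsets[dim] = coordinate * shard_sizes[dim];
  }
  return offsets;
}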
+ std::vector group_dims(sharding().tile_assignment().num_dimensions() - + 1); + std::iota(group_dims.begin(), group_dims.end(), 0); + auto sharding_grouped = GroupShardingOnDims(sharding(), group_dims); + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + state_, sharding_grouped.device_groups, state_.b); + // 2. Get the padded_hlo, do right halo exchange if needed. + auto padded_hlo = PadFromPartialReplicateShape( + hlo_, base_shape_, sharding(), temp_target_sharding, expand_tile_dims, + state_.collective_ops_creator, state_.next_channel_id, + state_.partition_id, state_.b); + if (!padded_hlo.has_value()) { + return absl::nullopt; + } + // 3. Slice out the tile from replicate ones. + auto shard_shape = MakePartitionedShape(base_shape_, temp_target_sharding); + // device assignment within each group is sorted in + // HloSharding::PartialTile, thus partiton_id within each group can be + // matched with the order in tile_assignment. + Array tiling_assignment(tiling_dim_factors); + tiling_assignment.FillIota(0); + auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, padded_hlo.value(), + MakePartitionOffsets(padded_hlo.value()->shape(), + target.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(tiling_assignment) + : HloSharding::Tile(tiling_assignment), + per_group_partitioner_state.partition_id, + per_group_partitioner_state.b), + shard_shape.dimensions())); + slice->set_sharding(temp_target_sharding); + auto result = PartitionedHlo(slice, base_shape_, state_); + // If temp_target_sharding's device assignment is different from target, + // use collective permute to reshard. + if (CanReshardWithCollectivePermute(temp_target_sharding, target)) { + return result.ReshardWithCollectivePermute(target); + } + // If device assignment in temp_target_sharding and target are the same, + // return result directly. + return result; +} + PartitionedHlo PartitionedHlo::Broadcast() const { const Shape& shape = hlo_->shape(); const HloSharding& sharding = hlo_->sharding(); @@ -813,8 +1127,9 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll( sharding().tile_assignment().dim(source_dim); temp_target_tile.Reshape(temp_target_tile_dims); } - auto temp_target = HloSharding::Tile(temp_target_tile); - + auto temp_target = target.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(temp_target_tile) + : HloSharding::Tile(temp_target_tile); auto padded_shape = hlo_->shape(); padded_shape.set_dimensions( target_dim, @@ -904,6 +1219,27 @@ PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute( const HloSharding& target) const { CHECK(CanReshardWithCollectivePermute(sharding(), target)) << sharding().ToString() << " to " << target.ToString(); + if (auto broadcast_dims = state_.b->BroadcastDimsForCreatedHlo(hlo())) { + if (!(*broadcast_dims)->empty()) { + // If hlo() has broadcast dims, check if data is already the same between + // source/destination pairs. 
+ std::vector broadcast_dims_vector; + for (int64 i = 0; i < hlo()->shape().rank(); ++i) { + if ((*broadcast_dims)->contains(i)) { + broadcast_dims_vector.push_back(i); + } + } + if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + sharding(), broadcast_dims_vector) == + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + target, broadcast_dims_vector)) { + auto copy = state_.b->AddInstruction(HloInstruction::CreateUnary( + hlo()->shape(), HloOpcode::kCopy, hlo())); + copy->set_sharding(target); + return PartitionedHlo(copy, base_shape_, state_); + } + } + } std::vector> src_dst_pairs; sharding().tile_assignment().Each( [&](absl::Span indices, int64 src_device) { @@ -1075,7 +1411,7 @@ namespace { // gather/scatter slice size 1. bool GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( const PartitionedHlo& operand, absl::Span index_map, - absl::Span slice_size, int64 num_partitions) { + absl::Span slice_size) { if (operand.sharding().IsTileMaximal()) { return false; } @@ -1086,7 +1422,7 @@ bool GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( operand.sharding().tile_assignment().dim(dim); } } - return trivial_slice_dims_partitions == num_partitions; + return trivial_slice_dims_partitions == operand.sharding().NumTiles(); } // Returns the min and max for the indices (replicated) in a scatter/gather @@ -1209,6 +1545,16 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { case HloOpcode::kAnd: identity = CreateOne(operand.hlo()->shape(), &b_); break; + case HloOpcode::kMinimum: + identity = CreateConstant( + operand.hlo()->shape(), + LiteralUtil::MaxValue(hlo->shape().element_type()), &b_); + break; + case HloOpcode::kMaximum: + identity = CreateConstant( + operand.hlo()->shape(), + LiteralUtil::MinValue(hlo->shape().element_type()), &b_); + break; default: return DefaultAction(hlo); } @@ -1221,14 +1567,29 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { update_dim_to_index_dim[update_scatter_dims[i]] = indices_scatter_dim; index_dim_to_update_dim[indices_scatter_dim] = update_scatter_dims[i]; } - auto new_updates_sharding = TransposeShardingWithCollapsedDims( - indices.sharding(), index_dim_to_update_dim, update_dim_to_index_dim); + auto new_updates_sharding = + hlo_sharding_util::TransposeShardingWithCollapsedDims( + indices.sharding(), index_dim_to_update_dim, + update_dim_to_index_dim); CHECK(new_updates_sharding.has_value()); updates = updates.Reshard(*new_updates_sharding); + // Update collective_ops_creator and partition_id for partial replicate. + auto collective_ops_creator = collective_ops_creator_; + auto partition_id = partition_id_; + if (indices.sharding().ReplicateOnLastTileDim()) { + auto sharding_grouped = GroupShardingOnDims( + indices.sharding(), + {indices.sharding().tile_assignment().num_dimensions() - 1}); + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + indices.state(), sharding_grouped.device_groups, &b_); + collective_ops_creator = + per_group_partitioner_state.collective_ops_creator; + partition_id = per_group_partitioner_state.partition_id; + } // To avoid accumulating the initial operand multiple times during - // all-reduce, we use zero operands for all non-zero partitions. + // all-reduce, we use identity operands for all non-zero partitions. 
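The scatter path now also handles min/max combiners by seeding non-owning partitions with the combiner's identity. A minimal sketch with an illustrative `Combiner` enum; for floating-point min/max the exact identities would be +/-infinity, so `lowest()`/`max()` are only the finite stand-ins here.

#include <limits>

enum class Combiner { kAdd, kMultiply, kMinimum, kMaximum };

// Identity element per combiner: contributing it from the non-owning
// partitions leaves the result unchanged after the cross-partition combine.
template <typename T>
T IdentityFor(Combiner op) {
  switch (op) {
    case Combiner::kAdd:      return T(0);
    case Combiner::kMultiply: return T(1);
    case Combiner::kMinimum:  return std::numeric_limits<T>::max();
    case Combiner::kMaximum:  return std::numeric_limits<T>::lowest();
  }
  return T(0);
}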
auto not_partition_zero = b_.AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::MakeScalarShape(PRED), partition_id_)); + ShapeUtil::MakeScalarShape(PRED), partition_id)); not_partition_zero = b_.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::ChangeElementType(identity->shape(), PRED), not_partition_zero, {})); @@ -1239,7 +1600,7 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { auto pscatter = b_.AddInstruction(scatter->CloneWithNewOperands( scatter->shape(), {select_operand, indices.hlo(), updates.hlo()})); auto all_reduce = - collective_ops_creator_.create_cross_partition_all_reduce( + collective_ops_creator.create_cross_partition_all_reduce( &b_, pscatter, scatter->to_apply(), {}, NewChannel()); all_reduce->set_sharding(HloSharding::Replicate()); SetPartitionedHlo(hlo, [&]() { @@ -1269,8 +1630,7 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { return Status::OK(); } if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( - operand, scatter_dims_to_operand_dims, slice_size, - num_partitions_) && + operand, scatter_dims_to_operand_dims, slice_size) && ShapeSizeInBytes(updates.base_shape()) < ShapeSizeInBytes(scatter->shape())) { // Operand is sharded on trivial slice dims (update slice size 1). We can @@ -1712,6 +2072,16 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { return Status::OK(); } + // Check if operand sharding and sharding are both tiled or partial replicate. + // If both of them are partial replicate, check num_replications are the same. + if (operand.sharding().ReplicateOnLastTileDim() != + sharding.ReplicateOnLastTileDim() || + (sharding.ReplicateOnLastTileDim() && + (operand.sharding().tile_assignment().dimensions().back() != + sharding.tile_assignment().dimensions().back()))) { + return DefaultAction(hlo); + } + // Try use halo exchange for certain split-dim/merge-dims cases. // ReshapeSharding failed in these cases probably due to uneven partitioning, // where halo exchange could help. Specifically we check the following @@ -1747,7 +2117,14 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { Array new_input_tile_assignment = sharding.tile_assignment(); new_input_tile_assignment.Reshape( operand.sharding().tile_assignment().dimensions()); - operand = operand.Reshard(HloSharding::Tile(new_input_tile_assignment)); + auto aligned_sharding = + sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_input_tile_assignment) + : HloSharding::Tile(new_input_tile_assignment); + operand = operand.Reshard(aligned_sharding); + auto replication_count = sharding.ReplicateOnLastTileDim() + ? 
sharding.tile_assignment().dimensions().back() + : 1; int64 input_dim_size = operand.base_shape().dimensions(input_sharded_dim); int64 output_dim_size = hlo->shape().dimensions(output_sharded_dim); @@ -1770,7 +2147,7 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { dim->set_padding_low(0); if (i == input_sharded_dim) { dim->set_padding_high(output_shard_size * split_factor * - num_partitions_ - + num_partitions_ / replication_count - input_dim_size); } else { dim->set_padding_high(0); @@ -1808,8 +2185,8 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { tmp_reshape->set_sharding(hlo->sharding()); auto tmp_full_shape = tmp_shard_shape; tmp_full_shape.set_dimensions( - output_sharded_dim, - tmp_shard_shape.dimensions(output_sharded_dim) * num_partitions_); + output_sharded_dim, tmp_shard_shape.dimensions(output_sharded_dim) * + num_partitions_ / replication_count); auto tmp_output = PartitionedHlo(tmp_reshape, tmp_full_shape, MakePartitioningState()); @@ -1826,7 +2203,7 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { if (i == output_sharded_dim) { dim->set_padding_high(output_dim_size - tmp_shard_shape.dimensions(output_sharded_dim) * - num_partitions_); + num_partitions_ / replication_count); } else { dim->set_padding_high(0); } @@ -1951,67 +2328,22 @@ Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) { auto& operand = GetPartitionedHlo(hlo->operand(0)); // Tiled output. - std::vector wanted_input_tile_size(operand.base_shape().rank()); - std::vector sharded_new_dims; - for (int64 i = 0; i < operand.base_shape().rank(); ++i) { - wanted_input_tile_size[i] = - hlo->sharding().tile_assignment().dim(hlo->dimensions(i)); - } + std::vector new_dims; for (int64 i = 0; i < hlo->shape().rank(); ++i) { - if (!absl::c_linear_search(hlo->dimensions(), i) && - hlo->sharding().tile_assignment().dim(i) > 1) { - sharded_new_dims.push_back(i); + if (!absl::c_linear_search(hlo->dimensions(), i)) { + new_dims.push_back(i); } } - if (sharded_new_dims.empty()) { - // The new dimensions are replicated, so that we can do the adjustment on - // the input. - Array wanted_input_tile_assignment(wanted_input_tile_size); - wanted_input_tile_assignment.Each( - [&](absl::Span indices, int64* val) { - std::vector indices_in_broadcast(hlo->shape().rank(), 0); - for (int64 i = 0; i < operand.base_shape().rank(); ++i) { - indices_in_broadcast[hlo->dimensions(i)] = indices[i]; - } - *val = hlo->sharding().tile_assignment()(indices_in_broadcast); - }); - SetPartitionedHlo(hlo, [&] { - return b_.AddInstruction(hlo->CloneWithNewOperands( - MakePartitionedShape(hlo->shape(), hlo->sharding()), - {operand.Reshard(HloSharding::Tile(wanted_input_tile_assignment)) - .hlo()})); - }); - } else { - auto input = operand.Reshard(HloSharding::Replicate()).hlo(); - // We pad and shard the input first, then broadcast to the final shard - // shape. 
- auto output_offsets = - MakePartitionOffsets(hlo->shape(), hlo->sharding(), partition_id_, &b_); - std::vector input_offsets(operand.base_shape().rank()); - auto output_shard_shape = - MakePartitionedShape(hlo->shape(), hlo->sharding()); - auto input_shard_shape = input->shape(); - auto padded_input_shape = input->shape(); - for (int64 i = 0; i < input_offsets.size(); ++i) { - input_offsets[i] = output_offsets[hlo->dimensions(i)]; - input_shard_shape.set_dimensions( - i, output_shard_shape.dimensions(hlo->dimensions(i))); - padded_input_shape.set_dimensions( - i, hlo->sharding().tile_assignment().dim(hlo->dimensions(i)) * - input_shard_shape.dimensions(i)); - } - auto padded_input = PadToShape(input, padded_input_shape, &b_); - auto input_shard = - ShapeUtil::Compatible(input_shard_shape, padded_input->shape()) - ? padded_input - : b_.AddInstruction(HloInstruction::CreateDynamicSlice( - input_shard_shape, padded_input, input_offsets, - input_shard_shape.dimensions())); - SetPartitionedHlo(hlo, [&] { - return b_.AddInstruction( - hlo->CloneWithNewOperands(output_shard_shape, {input_shard})); - }); - } + auto desired_input_sharding = hlo_sharding_util::RemoveShapeDimensions( + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(hlo->sharding(), + new_dims), + new_dims); + auto input = operand.Reshard(desired_input_sharding).hlo(); + auto output_shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction( + hlo->CloneWithNewOperands(output_shard_shape, {input})); + }); return Status::OK(); } @@ -2134,8 +2466,10 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) { output_dim_to_index_dim[batch_dims[i]] = indices_batch_dim; index_dim_to_output_dim[indices_batch_dim] = batch_dims[i]; } - auto pgather_sharding = TransposeShardingWithCollapsedDims( - indices.sharding(), index_dim_to_output_dim, output_dim_to_index_dim); + auto pgather_sharding = + hlo_sharding_util::TransposeShardingWithCollapsedDims( + indices.sharding(), index_dim_to_output_dim, + output_dim_to_index_dim); CHECK(pgather_sharding.has_value()); pgather->set_sharding(*pgather_sharding); SetPartitionedHlo(hlo, [&]() { @@ -2171,8 +2505,7 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) { return Status::OK(); } if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( - operand, start_index_map, gather->gather_slice_sizes(), - num_partitions_) && + operand, start_index_map, gather->gather_slice_sizes()) && ShapeSizeInBytes(gather->shape()) < ShapeSizeInBytes(gather->operand(0)->shape())) { indices = indices.Reshard(HloSharding::Replicate()); @@ -2234,7 +2567,17 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) { pgather->shape(), HloOpcode::kSelect, broadcast_filter, CreateZero(pgather->shape(), &b_), pgather)); // Combine from different partitions. 
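A shape-level sketch of the new broadcast lowering, under the assumption that shardings are represented as plain per-dimension tile counts: the operand inherits the output's tiling on the dimensions it maps to, while partitions along the broadcast's new dimensions fold into a trailing partial-replication factor (roughly what `PartiallyReplicateTiledShardingOnDims` followed by `RemoveShapeDimensions` produces). `OperandTiling` and `DesiredOperandTiling` are illustrative names.

#include <cstdint>
#include <vector>

struct OperandTiling {
  std::vector<int64_t> tile_dims;  // One entry per operand dimension.
  int64_t replication = 1;         // Partial replication on the last tile dim.
};

OperandTiling DesiredOperandTiling(
    const std::vector<int64_t>& output_tile_dims,
    const std::vector<int64_t>& operand_to_output_dim) {
  OperandTiling result;
  std::vector<bool> mapped(output_tile_dims.size(), false);
  for (int64_t out_dim : operand_to_output_dim) {
    result.tile_dims.push_back(output_tile_dims[out_dim]);
    mapped[out_dim] = true;
  }
  for (size_t i = 0; i < output_tile_dims.size(); ++i) {
    if (!mapped[i]) result.replication *= output_tile_dims[i];  // New dims.
  }
  return result;
}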
- auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + auto collective_ops_creator = collective_ops_creator_; + if (operand.sharding().ReplicateOnLastTileDim()) { + auto sharding_grouped = GroupShardingOnDims( + operand.sharding(), + {operand.sharding().tile_assignment().num_dimensions() - 1}); + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + operand.state(), sharding_grouped.device_groups, &b_); + collective_ops_creator = + per_group_partitioner_state.collective_ops_creator; + } + auto ar = collective_ops_creator.create_cross_partition_all_reduce( &b_, filtered, MakeBinaryAdd(filtered->shape().element_type(), module_), {}, NewChannel()); @@ -2492,7 +2835,13 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { .Reshard(HloSharding::Replicate()) .hlo()); inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id))); - if (operand_id > 0) { + if (hlo->shape().IsTuple() && operand_id == 0) { + // We cannot do tuple-reduce where partitioned dimensions are reduced. + // Partially replicate on those dims. + inputs[0] = inputs[0].Reshard( + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + inputs[0].sharding(), hlo->dimensions())); + } else { // Make sure all operands are sharded in the same way. inputs.back() = inputs.back().Reshard(inputs[0].sharding()); } @@ -2500,28 +2849,6 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { inputs.back() = inputs.back().PadWithValue(inits[operand_id]); } } - bool reduce_sharded_dimension = false; - if (!inputs[0].sharding().IsTileMaximal()) { - reduce_sharded_dimension = absl::c_any_of(hlo->dimensions(), [&](int64 i) { - return inputs[0].sharding().tile_assignment().dim(i) > 1; - }); - - // reduce_sharded_dimension is not supported for tuple-shaped reduces. - if (reduce_sharded_dimension && input_count > 1) { - return DefaultAction(hlo); - } - - // Currently we only support reducing all or none of the sharded - // dimensions. - if (reduce_sharded_dimension) { - for (int64 i = 0; i < inputs[0].base_shape().rank(); ++i) { - if (inputs[0].sharding().tile_assignment().dim(i) > 1 && - absl::c_count(hlo->dimensions(), i) == 0) { - return DefaultAction(hlo); - } - } - } - } std::vector new_operand_shapes(input_count * 2); for (int64 i = 0; i < input_count; ++i) { @@ -2533,7 +2860,6 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { auto reduce_shape, ShapeInference::InferReduceShape(new_operand_shapes, hlo->dimensions(), hlo->to_apply()->ComputeProgramShape())); - *reduce_shape.mutable_layout() = hlo->shape().layout(); std::vector input_hlos(input_count); for (int64 i = 0; i < input_count; ++i) { @@ -2544,36 +2870,35 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { local_reduce->set_metadata(hlo->metadata()); SetPartitionedHlo(hlo, [&]() { - HloInstruction* reduce; + HloInstruction* reduce = local_reduce; + const bool reduce_sharded_dimension = + !inputs[0].sharding().IsTileMaximal() && + absl::c_any_of(hlo->dimensions(), [&](int64 i) { + return inputs[0].sharding().tile_assignment().dim(i) > 1; + }); if (reduce_sharded_dimension) { CHECK(local_reduce->shape().IsArray()); - reduce = collective_ops_creator_.create_cross_partition_all_reduce( - &b_, local_reduce, hlo->to_apply(), {}, NewChannel()); - reduce->set_sharding(HloSharding::Replicate()); - } else { - reduce = local_reduce; - if (inputs[0].sharding().IsTileMaximal()) { - reduce->set_sharding(inputs[0].sharding()); - } else { - // Remove tile assignment dimensions that are reduced. 
- std::vector tile_dimensions; - for (int64 i = 0; i < input_hlos[0]->shape().rank(); ++i) { - if (absl::c_count(hlo->dimensions(), i) == 0) { - tile_dimensions.push_back( - inputs[0].sharding().tile_assignment().dim(i)); - } + std::vector preserved_dims; + for (int64 i = 0; i < inputs[0].base_shape().rank(); ++i) { + if (!absl::c_linear_search(hlo->dimensions(), i)) { + preserved_dims.push_back(i); } - Array new_tile = inputs[0].sharding().tile_assignment(); - new_tile.Reshape(tile_dimensions); - auto sharding = HloSharding::Tile(new_tile); - if (input_count > 1) { - std::vector tuple(input_count, sharding); - sharding = HloSharding::Tuple(hlo->shape(), tuple); - } - reduce->set_sharding(sharding); } + if (inputs[0].sharding().ReplicateOnLastTileDim()) { + preserved_dims.push_back(inputs[0].base_shape().rank()); + } + auto grouped = GroupShardingOnDims(inputs[0].sharding(), preserved_dims); + auto grouped_state = CreatePerGroupPartitioningState( + inputs[0].state(), grouped.device_groups, &b_); + reduce = grouped_state.collective_ops_creator + .create_cross_partition_all_reduce( + &b_, local_reduce, hlo->to_apply(), {}, NewChannel()); } - + auto sharding = hlo_sharding_util::RemoveShapeDimensions( + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + inputs[0].sharding(), hlo->dimensions()), + hlo->dimensions()); + reduce->set_sharding(sharding); return PartitionedHlo(reduce, hlo->shape(), MakePartitioningState()) .Reshard(hlo->sharding()) .hlo(); @@ -2692,18 +3017,37 @@ Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) { } TF_RET_CHECK(!hlo->sharding().IsTileMaximal()); - SetPartitionedHlo(hlo, [&] { - // Replicate the operands and run partitioned Rng on all devices. - std::vector new_operands; - for (int64 i = 0; i < hlo->operand_count(); ++i) { - new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) - .Reshard(HloSharding::Replicate()) - .hlo()); - } - return b_.AddInstruction(HloInstruction::CreateRng( + // Replicate the operands and run partitioned Rng on all devices. 
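// The all-reduce above now runs only inside device groups: the input sharding
// is grouped on the preserved (non-reduced) tile dimensions, plus the
// trailing replication dimension when the input uses last_tile_dim_replicate,
// and CreatePerGroupPartitioningState supplies a collective-ops creator
// scoped to each group. The sketch below is a standalone illustration of
// which devices land in the same all-reduce group (it is not the XLA
// GroupShardingOnDims helper): devices that agree on every preserved
// dimension, i.e. that differ only along the reduced dimensions, are grouped
// together.
#include <cstdint>
#include <map>
#include <vector>

std::vector<std::vector<int64_t>> AllReduceGroups(
    const std::vector<int64_t>& dims,     // tile-assignment dimensions
    const std::vector<int64_t>& devices,  // row-major device order
    const std::vector<bool>& is_reduced_dim) {
  std::map<std::vector<int64_t>, std::vector<int64_t>> groups;
  for (int64_t linear = 0; linear < static_cast<int64_t>(devices.size());
       ++linear) {
    // Decode the row-major coordinates of this tile-assignment entry.
    std::vector<int64_t> coords(dims.size());
    int64_t rest = linear;
    for (int64_t d = static_cast<int64_t>(dims.size()) - 1; d >= 0; --d) {
      coords[d] = rest % dims[d];
      rest /= dims[d];
    }
    // Group key: coordinates of the preserved dimensions only.
    std::vector<int64_t> key;
    for (size_t d = 0; d < dims.size(); ++d) {
      if (!is_reduced_dim[d]) key.push_back(coords[d]);
    }
    groups[key].push_back(devices[linear]);
  }
  std::vector<std::vector<int64_t>> result;
  for (auto& kv : groups) result.push_back(kv.second);
  return result;
}

// Example, matching the PartialTiledToPartialTiledReduce test added below:
// dims={2,2,2} (last dim is the replication dim), devices={0,...,7}, and
// reduce dimension 0 -> groups {0,4}, {1,5}, {2,6}, {3,7}, so each f32[2]
// partial result is summed across a pair of devices only.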
+ std::vector new_operands; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) + .Reshard(HloSharding::Replicate()) + .hlo()); + } + + if (!hlo->sharding().ReplicateOnLastTileDim()) { + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(HloInstruction::CreateRng( + MakePartitionedShape(hlo->shape(), hlo->sharding()), + hlo->random_distribution(), new_operands)); + }); + } else { + std::vector group_dims( + hlo->sharding().tile_assignment().num_dimensions() - 1); + std::iota(group_dims.begin(), group_dims.end(), 0); + auto sharding_grouped = GroupShardingOnDims(hlo->sharding(), group_dims); + auto per_group_state = CreatePerGroupPartitioningState( + MakePartitioningState(), sharding_grouped.device_groups, &b_); + auto rng = b_.AddInstruction(HloInstruction::CreateRng( MakePartitionedShape(hlo->shape(), hlo->sharding()), hlo->random_distribution(), new_operands)); - }); + rng->set_sharding(HloSharding::AssignDevice(0)); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(rng, rng->shape(), per_group_state) + .Replicate() + .hlo(); + }); + } return Status::OK(); } @@ -3258,7 +3602,7 @@ StatusOr SpmdPartitioner::Run(HloModule* module) { HloPassPipeline pass("spmd-cleanup"); pass.AddPass(); pass.AddPass(); - pass.AddPass(/*is_layout_sensitive=*/true); + pass.AddPass(/*is_layout_sensitive=*/false); pass.AddPass(); TF_RETURN_IF_ERROR(pass.Run(module).status()); } diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h index a612c16bdae..b09ea0c8e0b 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -47,6 +48,12 @@ struct SpmdPartitionerOptions { // Whether the entry computations' signature could change after partitioning. bool allow_module_signature_change = false; + + // Whether to use cached all-gather to avoid repeatedly replicate a tiled + // tensor. If it is set to false, the result tends to be more + // memory-efficient, and the compiler can use the ScheduleAwareAllGatherCSE + // pass to CSE some all-gathers which are relatively close to each other. + bool cache_all_gather = true; }; // Class to wrap the computation builder to capture information during SPMD @@ -68,6 +75,16 @@ class SpmdBuilder : public HloComputation::Builder { HloInstruction* visiting_hlo() const { return visiting_hlo_; } + // Wrapper of queries to broadcast_dims_. + absl::optional*> BroadcastDimsForCreatedHlo( + const HloInstruction* hlo) { + auto it = broadcast_dims_.find(hlo); + if (it == broadcast_dims_.end()) { + return absl::nullopt; + } + return &it->second; + } + private: // Currently visiting instruction. HloInstruction* visiting_hlo_; @@ -75,6 +92,12 @@ class SpmdBuilder : public HloComputation::Builder { // Map from the currently visiting (old) instruction to new instructions // created during SPMD partitioning. HloInstructionMap> instructions_; + + // Maps from each created instruction to a set of dimensions that are from + // broadcasts or elementwise ops over broadcasts. This means elements along + // these dimensions have the same value. 
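// Assumed illustration of the broadcast_dims_ invariant documented above (the
// rule below is an assumption about how the property propagates, not the
// builder's actual bookkeeping): a dimension keeps the "same value everywhere
// along this dim" property through an elementwise op only if every operand
// already has it on that dimension, so the tracked set is the intersection of
// the operands' sets.
#include <cstdint>
#include <set>
#include <vector>

std::set<int64_t> ElementwiseBroadcastDims(
    const std::vector<std::set<int64_t>>& operand_broadcast_dims) {
  if (operand_broadcast_dims.empty()) return {};
  std::set<int64_t> result = operand_broadcast_dims[0];
  for (size_t i = 1; i < operand_broadcast_dims.size(); ++i) {
    std::set<int64_t> intersection;
    for (int64_t dim : result) {
      if (operand_broadcast_dims[i].count(dim) > 0) intersection.insert(dim);
    }
    result = intersection;
  }
  return result;
}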
+ absl::flat_hash_map> + broadcast_dims_; }; // A set of functions that create the cross-partition collective ops. @@ -180,6 +203,8 @@ class SpmdPartitioner : public HloModulePass { int64 channel_id, absl::Span selected_dims, const SPMDCollectiveOpsCreator& collectives_creator); + const SpmdPartitionerOptions& options() { return options_; } + protected: virtual std::unique_ptr CreateVisitor( HloComputation* computation, int64 num_partitions, int64 num_replicas, @@ -305,6 +330,14 @@ class PartitionedHlo { // Helper function to reshard the tensor using CollectivePermute. PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const; + // Helper function to reshard to partial replicate using AllGather. + absl::optional ReshardToPartialReplicateWithAllGather( + const HloSharding& target); + + // Helper function to reshard from partial replicate using DynamicSlice. + absl::optional ReshardFromPartialReplicateWithDynamicSlice( + const HloSharding& target); + // SPMD instruction. HloInstruction* hlo_; @@ -314,27 +347,11 @@ class PartitionedHlo { PartitioningState state_; }; -struct DotGeneralDimsMapping { +struct DotConvDimsMapping { // The dimension numbers for the operands and output corresponding to a // logical dimension (e.g., batch, contracting, non-contracting). If an // operand or the output doesn't have the logical dimension, it is set to // -1. - struct DimsMapping { - int64 lhs; - int64 rhs; - int64 output; - }; - std::vector batch_dims; - std::vector contracting_dims; - std::vector lhs_non_contracting_dims; - std::vector rhs_non_contracting_dims; -}; - -struct ConvolutionDimsMapping { - // The dimension numbers for the operands and output corresponding to a - // logical dimension (e.g., batch, parallel, non-parallel). If an - // operand or the output doesn't have the logical dimension, it is set to - // -1. struct DimsMapping { int64 lhs; int64 rhs; @@ -342,8 +359,11 @@ struct ConvolutionDimsMapping { // input mapped to index in input_spatial_dimensions(). int64 spatial; }; - std::vector parallel_spatial_dims; - std::vector non_parallel_spatial_dims; + std::vector batch_dims; + std::vector contracting_dims; + std::vector lhs_non_contracting_dims; + std::vector rhs_non_contracting_dims; + std::vector conv_spatial_dims; }; class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { @@ -388,7 +408,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { // Implementation of dot partitioning given DotGeneralDimsMapping. 
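// Worked example for the merged DotConvDimsMapping above, which now serves
// both dot and convolution partitioning (conv_spatial_dims is simply empty
// for a plain dot). The mirror struct below uses int64_t in place of int64 so
// the snippet is self-contained; spatial is assumed to be -1 for non-spatial
// (pure dot) dimensions.
#include <cstdint>
#include <vector>

struct DimsMappingSketch {
  int64_t lhs;
  int64_t rhs;
  int64_t output;
  int64_t spatial;  // index into input_spatial_dimensions(); -1 if unused.
};

struct DotConvDimsMappingSketch {
  std::vector<DimsMappingSketch> batch_dims;
  std::vector<DimsMappingSketch> contracting_dims;
  std::vector<DimsMappingSketch> lhs_non_contracting_dims;
  std::vector<DimsMappingSketch> rhs_non_contracting_dims;
  std::vector<DimsMappingSketch> conv_spatial_dims;
};

// f32[2,24,100] dot f32[2,32,100] -> f32[2,24,32] with batch dims {0} and
// contracting dims {2} on both sides (the SimpleDotPartial test added below):
const DotConvDimsMappingSketch kSimpleDotPartialMapping = {
    /*batch_dims=*/{{0, 0, 0, -1}},
    /*contracting_dims=*/{{2, 2, -1, -1}},
    /*lhs_non_contracting_dims=*/{{1, -1, 1, -1}},
    /*rhs_non_contracting_dims=*/{{-1, 1, 2, -1}},
    /*conv_spatial_dims=*/{}};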
Status HandleDotHelper( - HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, + HloInstruction* hlo, const DotConvDimsMapping& dims_mapping, const std::function( HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot); diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index d5342e3e1f4..f3bd971df69 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -138,8 +138,7 @@ ENTRY entry { op::AllReduce(op::Select( op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), op::Constant(), op::Broadcast())), - op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), - op::Constant())), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())), op::Constant())), op::Shape("s32[1,3]"))); } @@ -161,8 +160,7 @@ ENTRY entry { op::Copy(op::AllReduce(AllOf( op::DynamicUpdateSlice( op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")), - op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), - op::Constant())), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())), op::Constant()), op::Shape("s32[2,3]"))))); } @@ -184,8 +182,7 @@ ENTRY entry { op::Copy(op::Copy(op::AllReduce(AllOf( op::DynamicUpdateSlice( op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")), - op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), - op::Constant())), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())), op::Constant()), op::Shape("s32[2,3]")))))); } @@ -279,8 +276,8 @@ ENTRY entry { HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_THAT(root, op::Tuple()); - auto offset = op::Reshape( - op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + auto offset = + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())); EXPECT_THAT(root->operand(0), op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset, @@ -305,13 +302,13 @@ ENTRY entry { PartitionComputation(hlo_string, /*num_devices=*/2)); HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT( - root, op::Copy(op::AllReduce(op::DynamicUpdateSlice( - op::Broadcast(), - op::GetTupleElement( - AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])"))), - op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), - op::Constant())), - op::Constant())))); + root, + op::Copy(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), + op::GetTupleElement( + AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])"))), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())), + op::Constant())))); } TEST_F(SpmdPartitioningTest, UnevenTiledInfeed) { @@ -564,6 +561,27 @@ ENTRY entry { op::Constant()))))); } +TEST_F(SpmdPartitioningTest, + BroadcastBothOldAndNewDimsShardedPartiallySharded) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + param = f32[4,3] parameter(0), + sharding={devices=[1,2,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate} + ROOT broadcast = f32[4,4,3] broadcast(param), dimensions={1,2}, + sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Shape("f32[2,4,2]"), + op::Broadcast(AllOf(op::Shape("f32[4,2]"), 
op::Parameter(0))))); +} + TEST_F(SpmdPartitioningTest, ConvWithParallelDimAndNonParallelSpatialDimPartitioned) { const char* const hlo_string = R"( @@ -1985,6 +2003,36 @@ ENTRY entry { EXPECT_THAT(root, op::DynamicSlice(pad, _)); } +TEST_F(SpmdPartitioningTest, PartialReplicatePad) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[11,7] parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %param1 = f32[] parameter(1), sharding={replicated} + ROOT %pad = f32[27,22] pad(%param0, %param1), padding=2_4_1x2_1_2, + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + + auto param0 = AllOf(op::Parameter(), op::Shape("f32[11,4]")); + auto after_halo_exchange = + AllOf(op::Shape("f32[11,4]"), + op::DynamicSlice( + AllOf(op::Shape("f32[11,5]"), + op::Concatenate(op::CollectivePermute(op::Slice(param0)), + param0)), + op::Constant(), _)); + auto pad = op::Pad(after_halo_exchange, op::Parameter(1)); + EXPECT_THAT(root, AllOf(op::DynamicSlice(pad, op::Constant(), _), + op::Shape("f32[27,11]"))); +} + TEST_F(SpmdPartitioningTest, SliceAlongNonPartitionedDimension) { const char* const hlo_string = R"( HloModule module @@ -2042,6 +2090,61 @@ ENTRY entry { op::Shape("f32[63,14,126]"))); } +TEST_F(SpmdPartitioningTest, + PartialReplicateSliceAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate} + ROOT %slice = f32[128,11,257] slice(%param0), + slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf(op::Parameter(), op::Shape("f32[128,14,129]")); + EXPECT_THAT(root, AllOf(op::Slice(param0), op::Shape("f32[128,11,129]"))); +} + +TEST_F(SpmdPartitioningTest, PartialReplicateSliceAlongPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate} + ROOT %slice = f32[63,14,251] slice(%param0), + slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf(op::Parameter(), op::Shape("f32[128,14,129]")); + EXPECT_THAT( + root, + AllOf( + op::Slice(AllOf( + op::DynamicSlice( + AllOf(op::Concatenate( + param0, + AllOf(op::CollectivePermute(op::Slice(param0)), + op::Shape("f32[128,14,2]"))), + op::Shape("f32[128,14,131]")), + op::Constant(), op::Constant(), + op::Add(op::Multiply(op::Reshape(op::DynamicSlice( + op::Constant(), op::PartitionId())), + op::Constant()), + op::Constant())), + op::Shape("f32[128,14,126]"))), + op::Shape("f32[63,14,126]"))); +} + TEST_F(SpmdPartitioningTest, SortAlongNonPartitionedDimension) { const char* const hlo_string = R"( HloModule module @@ -2577,6 +2680,79 @@ ENTRY entry { EXPECT_THAT(root, AllOf(op::Transpose(), 
op::Shape("f32[16,2,38,38]"))); } +TEST_F(SpmdPartitioningTest, PartialReplicateShardableTranspose) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[16,38,38,4] parameter(0) + %param0.copy = f32[16,38,38,4] copy(%param0), + sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy), + dimensions={0,3,1,2}, + sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,19,38,4]")); + EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[16,4,19,38]"))); +} + +TEST_F(SpmdPartitioningTest, PartialReplicateNonShardableTranspose) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[16,38,38,4] parameter(0) + %param0.copy = f32[16,38,38,4] copy(%param0), + sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy), + dimensions={0,3,1,2}, + sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto resahrd = AllOf(op::Reshape(op::Transpose(op::Reshape(op::AllToAll()))), + op::Shape("f32[16,38,38,2]")); + EXPECT_THAT(root, AllOf(op::Transpose(), op::Shape("f32[16,2,38,38]"))); +} + +TEST_F(SpmdPartitioningTest, PartialReplicateMultiDimensionShardedTranspose) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[16,38,38,4] parameter(0) + %param0.copy = f32[16,38,38,4] copy(%param0), + sharding={devices=[2,2,1,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy), + dimensions={1,3,0,2}, + sharding={devices=[2,1,2,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[8,19,38,4]")); + EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[19,4,8,38]"))); +} + TEST_F(SpmdPartitioningTest, ShardableReshape) { const char* const hlo_string = R"( HloModule module @@ -2600,6 +2776,30 @@ ENTRY entry { EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]"))); } +TEST_F(SpmdPartitioningTest, PartialReplicateShardableReshape) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[38,38,324] parameter(0) + %param0.copy = f32[38,38,324] copy(%param0), + sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy), + sharding={devices=[2,1,1,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = + 
AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[19,38,324]")); + EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]"))); +} + TEST_F(SpmdPartitioningTest, NonShardableReshape) { const char* const hlo_string = R"( HloModule module @@ -2652,6 +2852,30 @@ ENTRY entry { EXPECT_THAT(root, AllOf(exchanged, op::Shape("s32[3,2,1,7,5]"))); } +TEST_F(SpmdPartitioningTest, PartialReplicateReshapeMergeDimsWithHaloExchange) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = s32[2,3,7,10] parameter(0), + sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %reshape = s32[3,2,1,14,5] reshape(%input), + sharding={devices=[1,1,1,2,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto reshape = + AllOf(op::Reshape(op::Parameter(0)), op::Shape("s32[3,2,1,8,5]")); + auto halo = op::CollectivePermute(op::Slice(reshape)); + auto exchanged = + op::DynamicSlice(op::Concatenate(halo, reshape), _, _, _, _, _); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(exchanged, op::Shape("s32[3,2,1,7,5]"))); +} + // Produces an invalid module after transformation. TEST_F(SpmdPartitioningTest, InceptionV3_4_way_ReduceWindowDilated) { const char* const hlo_string = R"( @@ -2746,6 +2970,35 @@ ENTRY entry { AllOf(op::Reduce(param0, op::Constant()), op::Shape("f32[64]"))); } +TEST_F(SpmdPartitioningTest, PartialTiledToPartialTiledReduce) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %param0 = f32[4,4] parameter(0), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %constant.1 = f32[] constant(0), sharding={replicated} + ROOT %reduce = f32[4] reduce(%param0, %constant.1), dimensions={0}, + to_apply=%sum, + sharding={devices=[2,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Reduce(op::Parameter(0), op::Constant())), + op::Shape("f32[2]"))); +} + TEST_F(SpmdPartitioningTest, TiledToTiledTupleReduce) { const char* const hlo_string = R"( HloModule module @@ -2781,6 +3034,48 @@ ENTRY %main { op::Shape("(f32[14], s32[14])"))); } +TEST_F(SpmdPartitioningTest, TiledToTiledTupleReduce2) { + const char* const hlo_string = R"( +HloModule module + +%minmax_func { + %lhs_value = f32[] parameter(0) + %rhs_value = f32[] parameter(2) + %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT + %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value) + %lhs_index = s32[] parameter(1) + %rhs_index = s32[] parameter(3) + %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index) + ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5) +} + +ENTRY %main { + %param0 = f32[28,10] parameter(0), sharding={devices=[2,2]0,1,2,3} + %param1 = s32[28,10] parameter(1), sharding={devices=[2,2]0,1,2,3} + %init0 = f32[] parameter(2) + %init1 = s32[] parameter(3) + ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1), + dimensions={1}, to_apply=%minmax_func, + sharding={{devices=[2,2]0,1,2,3 last_tile_dim_replicate}, + {devices=[2,2]0,1,2,3 
last_tile_dim_replicate}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = + AllOf(op::Shape("f32[14,10]"), + op::AllReduce(op::DynamicUpdateSlice(_, op::Parameter(0), _, _))); + auto rhs = + AllOf(op::Shape("s32[14,10]"), + op::AllReduce(op::DynamicUpdateSlice(_, op::Parameter(1), _, _))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Reduce(lhs, rhs, op::Parameter(2), op::Parameter(3)), + op::Shape("(f32[14], s32[14])"))); +} + TEST_F(SpmdPartitioningTest, TiledToTiledReduceOutputReshard) { const char* const hlo_string = R"( HloModule module @@ -3633,6 +3928,35 @@ ENTRY entry { op::Shape("s32[2]"))); } +TEST_F(SpmdPartitioningTest, PartialReplicatedRng) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = s32[] parameter(0), sharding={replicated} + %rhs = s32[] parameter(1), sharding={replicated} + ROOT %rng = s32[8]{0} rng(%lhs, %rhs), + distribution=rng_uniform, + sharding={devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Parameter(0), op::Shape("s32[]")); + auto rhs = AllOf(op::Parameter(1), op::Shape("s32[]")); + auto partition_id = + AllOf(op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())), + op::Shape("u32[]")); + EXPECT_THAT( + root, AllOf(op::AllReduce(op::Select( + op::Broadcast(op::Compare(partition_id, op::Constant())), + op::Rng(lhs, rhs), op::Broadcast(op::Constant()))), + op::Shape("s32[4]"))); +} + TEST_F(SpmdPartitioningTest, DynamicSliceAlongNonPartitionedDimension) { const char* const hlo_string = R"( HloModule module @@ -3710,6 +4034,26 @@ ENTRY entry { op::Shape("f32[3,5]"))); } +TEST_F(SpmdPartitioningTest, PassthroughGather_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %indices = s32[3] parameter(1), sharding={replicated} + ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9}, sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)), + op::Shape("f32[3,5]"))); +} + TEST_F(SpmdPartitioningTest, IndexPassthroughGather) { const char* const hlo_string = R"( HloModule module @@ -3729,6 +4073,27 @@ ENTRY entry { op::Shape("f32[8,2,2]"))); } +TEST_F(SpmdPartitioningTest, IndexPassthroughGather_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9,8] parameter(0), sharding={replicated} + %indices = s32[4,2,4] parameter(1), + sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %gather = f32[8,4,4] gather(%input, %indices), offset_dims={0}, + collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1, + slice_sizes={1,1,8}, + sharding={devices=[1,2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + 
PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)), + op::Shape("f32[8,2,2]"))); +} + TEST_F(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) { const char* const hlo_string = R"( HloModule module @@ -3743,8 +4108,39 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, /*num_devices=*/2)); VLOG(1) << module->ToString(); - auto offset = op::Reshape( - op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + auto offset = + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())); + auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")); + auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())), + op::Shape("s32[2,3]")); + auto clamp = op::Clamp(min, op::Parameter(1), max); + auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min)); + auto mask = + op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max)); + auto masked = + op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::AllReduce(masked), op::Shape("f32[2,3,9]"))); +} + +TEST_F(SpmdPartitioningTest, + GatherPartitionedOnTrivialSliceDims_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[17,9] parameter(0), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + %indices = s32[2,3] parameter(1), sharding={replicated} + ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, + slice_sizes={1,9}, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto offset = + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())); auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")); auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())), op::Shape("s32[2,3]")); @@ -3788,6 +4184,39 @@ ENTRY entry { op::Shape("f32[2,5]"))); } +TEST_F(SpmdPartitioningTest, PassthroughScatter_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1), + op::Parameter(2)), + op::Shape("f32[2,5]"))); +} + TEST_F(SpmdPartitioningTest, IndexPassthroughScatter) { const char* const hlo_string = R"( HloModule module @@ -3822,6 +4251,76 @@ ENTRY entry { 
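// Standalone emulation of the lowering checked by the
// GatherPartitionedOnTrivialSliceDims tests above (assumed shard layout:
// partition p owns rows [p * shard_rows, (p + 1) * shard_rows) of the
// operand). Each partition clamps every index into its local row range,
// gathers from its shard, zeroes rows whose original index was out of range,
// and the per-partition results are combined by an all-reduce, emulated here
// with "+=" across a loop over partitions; this mirrors the
// Clamp/Subtract/Gather/Select/AllReduce pattern in the expectations, not the
// partitioner's actual code.
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<std::vector<float>> PartitionedGatherRows(
    const std::vector<std::vector<float>>& full_operand,  // [rows][cols]
    const std::vector<int64_t>& indices, int64_t num_partitions) {
  const int64_t rows = full_operand.size();
  const int64_t cols = full_operand.empty() ? 0 : full_operand[0].size();
  const int64_t shard_rows = (rows + num_partitions - 1) / num_partitions;
  std::vector<std::vector<float>> result(indices.size(),
                                         std::vector<float>(cols, 0.0f));
  for (int64_t p = 0; p < num_partitions; ++p) {
    const int64_t offset = p * shard_rows;  // the broadcast "min" in the tests
    for (size_t i = 0; i < indices.size(); ++i) {
      const int64_t clamped =
          std::min(std::max(indices[i], offset), offset + shard_rows - 1);
      const bool out_of_range =
          indices[i] < offset || indices[i] > offset + shard_rows - 1;
      if (out_of_range || clamped >= rows) continue;  // contributes zeros
      for (int64_t c = 0; c < cols; ++c) {
        // Exactly one partition contributes a non-zero row per in-bounds
        // index, so the sum reproduces the unpartitioned gather.
        result[i][c] += full_operand[clamped][c];
      }
    }
  }
  return result;
}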
op::Shape("f32[2,9,8]"))); } +TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9,8] parameter(0), sharding={replicated} + %indices = s32[4,2,4] parameter(1), + sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %updates = f32[4,4,8] parameter(2), + sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={2}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::AllReduce(op::Scatter( + op::Select(op::Broadcast(op::Convert(op::Reshape())), + op::Broadcast(op::Constant()), op::Parameter(0)), + op::Parameter(1), op::Parameter(2))), + op::Shape("f32[2,9,8]"))); +} + +TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_Min) { + const char* const hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9,8] parameter(0), sharding={replicated} + %indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3} + %updates = f32[4,4,8] parameter(2), sharding={devices=[2,2,1]0,1,2,3} + ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates), + to_apply=min, + update_window_dims={2}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::AllReduce(op::Scatter( + op::Select(op::Broadcast(op::Convert(op::PartitionId())), + op::Broadcast(op::Constant()), op::Parameter(0)), + op::Parameter(1), op::Parameter(2))), + op::Shape("f32[2,9,8]"))); +} + TEST_F(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) { const char* const hlo_string = R"( HloModule module @@ -3846,8 +4345,45 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, /*num_devices=*/2)); VLOG(1) << module->ToString(); - auto offset = op::Reshape( - op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + auto offset = + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())); + auto indices = op::Subtract( + op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"))); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)), + op::Shape("f32[9,9]"))); +} + +TEST_F(SpmdPartitioningTest, + ScatterPartitionedOnTrivialSliceDims_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[17,9] parameter(0), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + 
%indices = s32[2,3] parameter(1), sharding={replicated} + %updates = f32[2,3,9] parameter(2), sharding={replicated} + ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2, + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto offset = + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())); auto indices = op::Subtract( op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"))); HloInstruction* root = module->entry_computation()->root_instruction(); @@ -4035,7 +4571,7 @@ HloModule module ENTRY entry { %lhs = f32[48,12] parameter(0), sharding={devices=[2,2]0,1,2,3} - %rhs = f32[32,12] parameter(1), sharding={devices=[2,2]0,1,2,3} + %rhs = f32[32,12] parameter(1), sharding={devices=[2,2]0,2,1,3} ROOT %dot = f32[48,32] dot(%lhs, %rhs), lhs_batch_dims={}, rhs_batch_dims={}, lhs_contracting_dims={1}, rhs_contracting_dims={1}, @@ -4052,8 +4588,8 @@ ENTRY entry { op::AllReduce(op::DynamicUpdateSlice(_, lhs, _, _))); auto rhs = AllOf(op::Shape("f32[16,6]"), op::Parameter(1)); auto partial_replicated_rhs = - AllOf(op::Shape("f32[16,12]"), op::AllReduce(op::DynamicUpdateSlice( - _, op::CollectivePermute(rhs), _, _))); + AllOf(op::Shape("f32[16,12]"), + op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _))); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, AllOf(op::Dot(partial_replicated_lhs, partial_replicated_rhs), @@ -4264,6 +4800,1099 @@ ENTRY entry { EXPECT_THAT(root, AllOf(op::Shape("f32[4,4,12,32]"), op::Reshape(xpose))); } +TEST_F(SpmdPartitioningTest, SimpleDotPartial) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[2,24,100] parameter(0), + sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate} + %rhs = f32[2,32,100] parameter(1), + sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %dot = f32[2,24,32] dot(%lhs, %rhs), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[1,24,100]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[1,32,100]"), op::Parameter(1)); + auto dot = AllOf(op::Shape("f32[1,24,32]"), op::Dot(lhs, rhs)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, dot); +} + +TEST_F(SpmdPartitioningTest, DotPartialContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,100] parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %rhs = f32[32,100] parameter(1), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + ROOT %dot = f32[24,32] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[32,50]"), op::Parameter(1)); + auto dot = AllOf(op::Shape("f32[24,32]"), op::Dot(lhs, rhs)); + auto root = 
module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::AllReduce(dot)); +} + +TEST_F(SpmdPartitioningTest, DotPartialContracting2) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,100] parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %rhs = f32[32,100] parameter(1), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + ROOT %dot = f32[24,32] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[32,50]"), op::Parameter(1)); + auto dot = + AllOf(op::Shape("f32[12,32]"), + op::Dot(AllOf(op::Shape("f32[12,50]"), op::DynamicSlice(lhs, _, _)), + rhs)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::AllReduce(dot)); +} + +TEST_F(SpmdPartitioningTest, DotPartialContracting3) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,100] parameter(0), + sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %rhs = f32[32,100] parameter(1), + sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %dot = f32[24,32] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0)); + auto rhs = + AllOf(op::Shape("f32[16,50]"), op::DynamicSlice(op::Parameter(1), _, _)); + auto dot = AllOf(op::Shape("f32[24,16]"), op::Dot(lhs, rhs)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::CollectivePermute(op::AllReduce(dot))); +} + +TEST_F(SpmdPartitioningTest, DotBatchAndPartialContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[4,24,100] parameter(0), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7} + %rhs = f32[4,32,100] parameter(1), + sharding={devices=[2,1,2,2]0,2,1,3,4,6,5,7 last_tile_dim_replicate} + ROOT %dot = f32[4,24,32] dot(%lhs, %rhs), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[2,12,50]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[2,32,50]"), op::Parameter(1)); + auto dot = AllOf(op::Shape("f32[2,12,32]"), op::Dot(lhs, rhs)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::AllReduce(dot)); +} + +TEST_F(SpmdPartitioningTest, DotPartialNonContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,8,100] parameter(0), + sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate} + %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,2,1,3} + ROOT %dot = f32[24,8,32] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={2}, rhs_contracting_dims={1}, + 
sharding={devices=[2,1,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[12,8,100]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[16,50]"), op::Parameter(1)); + auto partially_replicated_rhs = + AllOf(op::Shape("f32[16,100]"), + op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(_), rhs, _, _))); + auto dot = + AllOf(op::Shape("f32[12,8,16]"), op::Dot(lhs, partially_replicated_rhs)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, dot); +} + +TEST_F(SpmdPartitioningTest, DotPartialNonContractingPartialMatch) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,8,100] parameter(0), sharding={devices=[2,2,1]0,1,2,3} + %rhs = f32[32,100] parameter(1), + sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate} + ROOT %dot = f32[24,8,32] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={2}, rhs_contracting_dims={1}, + sharding={devices=[2,1,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[12,4,100]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[16,100]"), op::Parameter(1)); + auto partially_replicated_lhs = AllOf( + op::Shape("f32[12,8,100]"), + op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(_), lhs, _, _, _))); + auto dot = + AllOf(op::Shape("f32[12,8,16]"), op::Dot(partially_replicated_lhs, rhs)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, dot); +} + +TEST_F(SpmdPartitioningTest, DotPartialContractingPartialMatch) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,8,100] parameter(0), sharding={devices=[1,2,2]0,1,2,3} + %rhs = f32[32,8,100] parameter(1), + sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate} + ROOT %dot = f32[24,32] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1,2}, rhs_contracting_dims={1,2}, + sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[24,4,50]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[32,8,50]"), op::Parameter(1)); + auto dot = AllOf(op::Shape("f32[24,32]"), + op::Dot(lhs, AllOf(op::Shape("f32[32,4,50]"), + op::DynamicSlice(rhs, _, _, _)))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::AllReduce(op::AllReduce(dot))); +} + +TEST_F(SpmdPartitioningTest, DotNonContractingPartialMatchContractingMatch) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,8,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3} + %rhs = f32[100,50] parameter(1), sharding={devices=[2,2]0,2,1,3} + ROOT %dot = f32[24,8,50] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={2}, rhs_contracting_dims={0}, + sharding={devices=[2,2,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[12,8,50]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[50,25]"), op::Parameter(1)); + auto dot = AllOf( + op::Shape("f32[12,8,50]"), + op::Dot(lhs, AllOf(op::Shape("f32[50,50]"), + 
op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _))))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[12,4,50]"), + op::DynamicSlice(op::AllReduce(dot), _, _, _))) + << module->ToString(); +} + +TEST_F(SpmdPartitioningTest, DotLHSMutiNonContractingRHSNotMatch) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,8,10] parameter(0), sharding={devices=[2,2,1]0,1,2,3} + %rhs = f32[10,50] parameter(1), + sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate} + ROOT %dot = f32[24,8,50] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={2}, rhs_contracting_dims={0}, + sharding={devices=[2,2,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[12,4,10]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[5,50]"), op::Parameter(1)); + auto dot = AllOf( + op::Shape("f32[12,4,50]"), + op::Dot(lhs, AllOf(op::Shape("f32[10,50]"), + op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _))))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, dot) << module->ToString(); +} + +TEST_F(SpmdPartitioningTest, + ElementwiseTest_PartialReplicateToTiledHaloExchange) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[6,3]{1,0} + constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}), + sharding={replicated} + constant.1 = f32[6,3]{1,0} + constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}), + sharding={replicated} + multiply = f32[6,3]{1,0} multiply(constant, constant.1), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT add = f32[6,3]{1,0} add(multiply, constant.1), + sharding={devices=[4,1]0,1,2,3} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto partial_replicate_lhs = + AllOf(op::Shape("f32[3,3]"), + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant())); + auto partial_replicate_rhs = + AllOf(op::Shape("f32[3,3]"), + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant())); + auto multiply = + AllOf(op::Shape("f32[3,3]"), + op::Multiply(partial_replicate_lhs, partial_replicate_rhs)); + auto right_halo = + AllOf(op::Shape("f32[1,3]"), op::CollectivePermute(op::Slice(multiply))); + auto add_lhs = AllOf( + op::Shape("f32[2,3]"), + op::DynamicSlice( + op::DynamicSlice( + op::Pad(op::Concatenate(multiply, right_halo), op::Constant()), + op::Reshape(), op::Constant()), + op::Reshape(), op::Constant())); + auto add_rhs = AllOf(op::Shape("f32[2,3]"), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]"), op::Add(add_lhs, add_rhs))); +} + +TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0) + %copy = f32[8,8] copy(%param0), + sharding={devices=[2,2]0,1,2,3} + ROOT %copy0 = f32[8,8] copy(%copy), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto tiled = AllOf(op::Shape("f32[4,4]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), 
+ op::Reshape()))); + auto partially_replicated = AllOf( + op::Shape("f32[4,8]"), op::Copy(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(_), tiled, _, _)))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, partially_replicated); +} + +TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0) + %copy = f32[8,8] copy(%param0), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %copy0 = f32[8,8] copy(%copy), + sharding={devices=[2,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto partially_replicated = + AllOf(op::Shape("f32[4,8]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant()))); + auto tiled = AllOf(op::Shape("f32[4,4]"), + op::Copy(op::DynamicSlice(partially_replicated, + op::Constant(), op::Reshape()))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, tiled); +} + +TEST_F(SpmdPartitioningTest, + PartialReplicateToPartialReplicateReshard_AllReduce) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0) + %copy = f32[8,8] copy(param0), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy0 = f32[8,8] copy(%copy), + sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + VLOG(1) << module->ToString(); + auto partially_replicated_init = + AllOf(op::Shape("f32[4,4]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Reshape()))); + auto partially_replicated = + AllOf(op::Shape("f32[4,8]"), + op::Copy(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(_), partially_replicated_init, _, _)))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, partially_replicated); +} + +TEST_F(SpmdPartitioningTest, + PartialReplicateToPartialReplicateReshard_DynamicSlice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0) + %copy = f32[8,8] copy(%param0), + sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy0 = f32[8,8] copy(%copy), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + auto partially_replicated = + AllOf(op::Shape("f32[4,8]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant()))); + auto tiled = AllOf(op::Shape("f32[4,4]"), + op::Copy(op::DynamicSlice(partially_replicated, + op::Constant(), op::Reshape()))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, tiled); +} + +TEST_F(SpmdPartitioningTest, + PartialReplicateToPartialReplicateReshard_DynamicSlice2) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0) + %copy = f32[8,8] copy(%param0), + sharding={devices=[1,1,8]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy0 = f32[8,8] copy(%copy), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + auto 
partially_replicated = + AllOf(op::Shape("f32[8,8]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Constant()))); + auto tiled = AllOf(op::Shape("f32[4,4]"), + op::Copy(op::DynamicSlice(partially_replicated, + op::Reshape(), op::Reshape()))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, tiled); +} + +TEST_F(SpmdPartitioningTest, + PartialReplicateToPartialReplicateReshardWithCollectivePermute) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0) + %copy = f32[8,8] copy(param0), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy0 = f32[8,8] copy(%copy), + sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + VLOG(1) << module->ToString(); + auto partially_replicated_init = + AllOf(op::Shape("f32[4,4]"), + op::CollectivePermute(op::Copy(op::DynamicSlice( + op::Parameter(0), op::Reshape(), op::Reshape())))); + auto partially_replicated = + AllOf(op::Shape("f32[8,4]"), + op::Copy(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(_), partially_replicated_init, _, _)))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, partially_replicated); +} + +TEST_F(SpmdPartitioningTest, + PartialReplicateToPartialReplicateReshardCollectivePermute1) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0) + %copy = f32[8,8] copy(%param0), + sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy0 = f32[8,8] copy(%copy), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + auto partially_replicated = + AllOf(op::Shape("f32[8,4]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Reshape()))); + auto tiled = + AllOf(op::Shape("f32[4,4]"), + op::Copy(op::CollectivePermute(op::DynamicSlice( + partially_replicated, op::Reshape(), op::Constant())))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, tiled); +} + +TEST_F(SpmdPartitioningTest, + PartialReplicateToPartialReplicateReshardHaloExchange) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[6,3] parameter(0) + %copy = f32[6,3] copy(param0), + sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy0 = f32[6,3] copy(%copy), + sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + VLOG(1) << module->ToString(); + auto partially_replicated_init = + AllOf(op::Shape("f32[2,3]"), + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(0), op::Constant()), + op::Reshape(), op::Constant()))); + auto slice = + AllOf(op::Shape("f32[2,3]"), + op::DynamicSlice(op::Concatenate(op::CollectivePermute(op::Slice( + partially_replicated_init)), + partially_replicated_init), + _, _)); + auto partially_replicated = + AllOf(op::Shape("f32[3,3]"), + op::Copy(op::Slice(op::AllReduce( + op::DynamicUpdateSlice(op::Broadcast(_), slice, _, _))))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, partially_replicated); +} + +TEST_F(SpmdPartitioningTest, + PartialReplicateToPartialReplicateReshardHaloExchange1) 
{ + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[6,3] parameter(0) + %copy = f32[6,3] copy(param0), + sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy0 = f32[6,3] copy(%copy), + sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + VLOG(1) << module->ToString(); + auto partially_replicated_init = + AllOf(op::Shape("f32[3,3]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant()))); + auto slice = AllOf( + op::Shape("f32[4,3]"), + op::DynamicSlice(op::Pad(op::Concatenate(partially_replicated_init, + op::CollectivePermute(op::Slice( + partially_replicated_init))), + op::Constant()), + _, _)); + auto partially_replicated = + AllOf(op::Shape("f32[2,3]"), op::Copy(op::DynamicSlice(slice, _, _))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, partially_replicated); +} + +TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCount) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,1,1,2]0,1} + %rhs = f32[16,801,1,1024] parameter(1) + %rhs.copy = f32[16,801,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=f01b_i01o->01bf,batch_group_count=1024, + window={size=801x1 pad=2_2x0_0}, + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + EXPECT_THAT(root, + AllOf(op::Convolution(lhs, rhs), op::Shape("f32[5,1,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCountRHSAlignWithLHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,1,1,2]0,1} + %rhs = f32[16,801,1,1024] parameter(1) + %rhs.copy = f32[16,801,1,1024] copy(%rhs), + sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=f01b_i01o->01bf,batch_group_count=1024, + window={size=801x1 pad=2_2x0_0}, + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Shape("f32[16,401,1,1024]")); + auto resharded_rhs = AllOf( + op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(rhs))))), + op::Shape("f32[16,801,1,512]")); + EXPECT_THAT(root, AllOf(op::Convolution(lhs, 
resharded_rhs), + op::Shape("f32[5,1,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCountLHSAlignWithRHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[16,801,1,1024] parameter(1) + %rhs.copy = f32[16,801,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=f01b_i01o->01bf,batch_group_count=1024, + window={size=801x1 pad=2_2x0_0}, + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Shape("f32[16,401,1,1024]")); + auto resharded_lhs = AllOf( + op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))), + op::Shape("f32[16,801,1,512]")); + EXPECT_THAT(root, AllOf(op::Convolution(resharded_lhs, rhs), + op::Shape("f32[5,1,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvWithBathGroupCountOutputAlignWithLHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,1,1,2]0,1} + %rhs = f32[16,801,1,1024] parameter(1) + %rhs.copy = f32[16,801,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=f01b_i01o->01bf,batch_group_count=1024, + window={size=801x1 pad=2_2x0_0}, + sharding={devices=[2,1,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto conv = AllOf(op::Convolution(lhs, rhs), op::Shape("f32[5,1,1,512]")); + EXPECT_THAT(root, AllOf(op::Reshape(op::Transpose(op::AllToAll( + op::Reshape(op::Pad(conv, op::Constant()))))), + op::Shape("f32[3,1,1,1024]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvWithBathGroupCountOutputAlignWithRHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[16,801,1,1024] parameter(1) + %rhs.copy = f32[16,801,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=f01b_i01o->01bf,batch_group_count=1024, + window={size=801x1 pad=2_2x0_0}, + sharding={devices=[2,1,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = 
module->entry_computation()->root_instruction(); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Shape("f32[16,401,1,1024]")); + auto resharded_lhs = AllOf( + op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))), + op::Shape("f32[16,801,1,512]")); + auto conv = + AllOf(op::Convolution(resharded_lhs, rhs), op::Shape("f32[5,1,1,512]")); + EXPECT_THAT(root, AllOf(op::Reshape(op::Transpose(op::AllToAll( + op::Reshape(op::Pad(conv, op::Constant()))))), + op::Shape("f32[3,1,1,1024]"))); +} + +TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCount) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,1,1,2]0,1} + %rhs = f32[5,1,1,1024] parameter(1) + %rhs.copy = f32[5,1,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[5,1,1,512]")); + EXPECT_THAT(root, + AllOf(op::Convolution(lhs, rhs), op::Shape("f32[16,801,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvWithFeatureGroupCountRHSAlignWithLHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,1,1,2]0,1} + %rhs = f32[5,1,1,1024] parameter(1) + %rhs.copy = f32[5,1,1,1024] copy(%rhs), + sharding={devices=[2,1,1,1]0,1} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Reshape(), + op::Constant(), op::Constant(), op::Constant())), + op::Shape("f32[3,1,1,1024]")); + auto resharded_rhs = AllOf( + op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(rhs))))), + op::Shape("f32[5,1,1,512]")); + EXPECT_THAT(root, AllOf(op::Convolution(lhs, resharded_rhs), + op::Shape("f32[16,801,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvWithFeatureGroupCountLHSAlignWithRHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) 
+ %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[5,1,1,1024] parameter(1) + %rhs.copy = f32[5,1,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Shape("f32[16,401,1,1024]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[5,1,1,512]")); + auto resharded_lhs = AllOf( + op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))), + op::Shape("f32[16,801,1,512]")); + EXPECT_THAT(root, AllOf(op::Convolution(resharded_lhs, rhs), + op::Shape("f32[16,801,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvWithFeatureGroupCountAlignOuputWithLHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,1,1,2]0,1} + %rhs = f32[5,1,1,1024] parameter(1) + %rhs.copy = f32[5,1,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[2,1,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[5,1,1,512]")); + auto conv = AllOf(op::Convolution(lhs, rhs), op::Shape("f32[16,801,1,512]")); + EXPECT_THAT(root, + AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(conv)))), + op::Shape("f32[8,801,1,1024]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvWithFeatureGroupCountAlignOuputWithRHS) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[5,1,1,1024] parameter(1) + %rhs.copy = f32[5,1,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2]0,1} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[2,1,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Shape("f32[16,401,1,1024]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), 
op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[5,1,1,512]")); + auto resharded_lhs = AllOf( + op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))), + op::Shape("f32[16,801,1,512]")); + auto conv = AllOf(op::Convolution(resharded_lhs, rhs), + op::Shape("f32[16,801,1,512]")); + EXPECT_THAT(root, + AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(conv)))), + op::Shape("f32[8,801,1,1024]"))); +} + +TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCountBackProp) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,1,1,2]0,1} + %rhs = f32[5,1,1024,1] parameter(1) + %rhs.copy = f32[5,1,1024,1] copy(%rhs), + sharding={devices=[1,1,2,1]0,1} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01oi->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0 rhs_reversal=1x1}, + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[16,801,1,512]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[5,1,512,1]")); + EXPECT_THAT(root, + AllOf(op::Convolution(lhs, rhs), op::Shape("f32[16,801,1,512]"))); +} + +TEST_F(SpmdPartitioningTest, NoReshardOnBroadcastDims) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[2,3] parameter(0) + %param1 = f32[2,3,20] parameter(1) + %br0 = f32[20,2,20,3,20] broadcast(%param0), dimensions={1,3}, sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7} + %br1 = f32[20,2,20,3,20] broadcast(%param1), dimensions={1,3,4}, sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7} + %add = f32[20,2,20,3,20] add(%br0, %br1), sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7} + %reshape = f32[10,4,10,6,20] reshape(%br0), sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7} + %transpose = f32[2,3,20,20,20] transpose(%br0), dimensions={1,3,0,2,4}, sharding={devices=[1,1,2,2,2]0,1,2,3,4,5,6,7} + %copy_add0 = f32[20,2,20,3,20] copy(%add), sharding={devices=[2,1,2,1,2]6,7,2,3,4,5,0,1} + %copy_add1 = f32[20,2,20,3,20] copy(%add), sharding={devices=[2,1,2,1,2]7,6,3,2,5,4,0,1} + %copy_reshape = f32[10,4,10,6,20] copy(%reshape), sharding={devices=[2,1,2,1,2]7,6,3,2,5,4,0,1} + %copy_transpose = f32[2,3,20,20,20] copy(%transpose), sharding={devices=[1,1,2,2,2]7,6,3,2,5,4,0,1} + ROOT %tuple = (f32[20,2,20,3,20], f32[20,2,20,3,20], f32[10,4,10,6,20], f32[2,3,20,20,20]) + tuple(%copy_add0, %copy_add1, %copy_reshape, %copy_transpose), + sharding={{devices=[2,1,2,1,2]6,7,2,3,4,5,0,1},{devices=[2,1,2,1,2]7,6,3,2,5,4,0,1},{devices=[2,1,2,1,2]7,6,3,2,5,4,0,1},{devices=[1,1,2,2,2]7,6,3,2,5,4,0,1}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + // Reshard on copy_add0 only happens on broadcast dims, can be skipped. + auto copy_add0 = + op::Copy(op::Copy(op::Add(op::Broadcast(_), op::Broadcast(_)))); + // Reshard on copy_add1 also happens on non-broadcast dims. 
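+  // It therefore cannot be skipped, and a collective-permute is expected in + // the matcher below.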
+ auto copy_add1 = op::Copy( + op::CollectivePermute(op::Add(op::Broadcast(_), op::Broadcast(_)))); + // Reshard on copy_reshape only happens on broadcast dims, can be skipped. + auto copy_reshape = op::Copy(op::Copy(op::Reshape(op::Broadcast(_)))); + // Reshard on copy_transpose only happens on broadcast dims, can be skipped. + auto copy_transpose = op::Copy(op::Copy(op::Transpose(op::Broadcast(_)))); + EXPECT_THAT(root, + op::Tuple(copy_add0, copy_add1, copy_reshape, copy_transpose)); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionFilterIFOFPartitionedInputPartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,112,112,12] parameter(0) + %lhs.copy = f32[128,112,112,12] copy(f32[128,112,112,12] %lhs), + sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate} + %rhs = f32[7,7,12,64] parameter(1) + %rhs.copy = f32[7,7,12,64] copy(f32[7,7,12,64] %rhs), + sharding={devices=[1,1,2,2]0,1,2,3} + ROOT %conv = f32[128,56,56,64] convolution( + f32[128,112,112,12] %lhs.copy, + f32[7,7,12,64] %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[128,112,112,6]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Reshape(), op::Reshape())), + op::Shape("f32[7,7,6,32]")); + + EXPECT_THAT( + root, + AllOf(op::CollectivePermute(op::AllReduce(op::Convolution(lhs, rhs))), + op::Shape("f32[128,56,56,32]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionInputKernelNonContractingDimPartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,256] parameter(0) + %lhs.copy = f32[128,56,56,256] copy(%lhs), + sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate} + %rhs = f32[128,28,28,512] parameter(1) + %rhs.copy = f32[128,28,28,512] copy(%rhs), + sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate} + ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, + sharding={devices=[1,1,2,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[128,56,56,128]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[128,28,28,256]")); + + EXPECT_THAT(root, AllOf(op::Convolution(lhs, op::CollectivePermute(rhs)), + op::Shape("f32[1,1,128,256]"))); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc index 29def16f89d..0edbd4f2b8d 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc @@ -29,12 
+29,15 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/hlo_sharding_util.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -47,6 +50,23 @@ bool HasReplicatedSharding(const HloSharding& sharding) { return sharding.IsReplicated(); } +HloInstruction* CreateConstant(const Shape& shape, Literal value, + SpmdBuilder* b) { + if (shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + elements.push_back(CreateConstant( + ShapeUtil::GetTupleElementShape(shape, i), value.Clone(), b)); + } + return b->AddInstruction(HloInstruction::CreateTuple(elements)); + } + + CHECK( + ShapeUtil::IsScalarWithElementType(value.shape(), shape.element_type())); + auto c = b->AddInstruction(HloInstruction::CreateConstant(std::move(value))); + return b->AddInstruction(HloInstruction::CreateBroadcast(shape, c, {})); +} + HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) { if (shape.IsTuple()) { std::vector elements; @@ -183,13 +203,17 @@ std::vector MakePartitionOffsets( absl::Span dims) { CHECK(!shape.IsTuple()); - Array2D offset_array( - {sharding.tile_assignment().num_elements(), shape.rank()}); - offset_array.Each([&](int64 i, int64 j, int32* value) { - *value = sharding.TileOffsetForDevice(shape, i)[j]; - }); - auto offset_table = b->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2FromArray2D(offset_array))); + std::vector> offset_arrays(shape.rank()); + for (int64 i = 0; i < shape.rank(); ++i) { + offset_arrays[i].resize(sharding.tile_assignment().num_elements()); + } + auto shard_shape = MakePartitionedShape(shape, sharding); + sharding.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + for (int64 i = 0; i < shape.rank(); ++i) { + offset_arrays[i][device] = indices[i] * shard_shape.dimensions(i); + } + }); std::vector offsets; for (int64 i = 0; i < shape.rank(); ++i) { if (sharding.tile_assignment().dim(i) == 1 || @@ -197,11 +221,10 @@ std::vector MakePartitionOffsets( offsets.push_back(b->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(S32)))); } else { + auto offset_table = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(offset_arrays[i]))); auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice( - ShapeUtil::MakeShape(S32, {1, 1}), offset_table, - {partition_id, b->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(i)))}, - {1, 1})); + ShapeUtil::MakeShape(S32, {1}), offset_table, {partition_id}, {1})); offsets.push_back(b->AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index))); } @@ -212,8 +235,11 @@ std::vector MakePartitionOffsets( std::vector MakeTiledPartitionOrdinals( const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) { CHECK(!sharding.IsTileMaximal()); - auto table_shape = - ShapeUtil::MakeShape(S32, 
sharding.tile_assignment().dimensions()); + auto dimensions = sharding.tile_assignment().dimensions(); + if (sharding.ReplicateOnLastTileDim()) { + dimensions.pop_back(); + } + auto table_shape = ShapeUtil::MakeShape(S32, dimensions); return MakePartitionOffsets(table_shape, sharding, partition_id, b); } @@ -270,12 +296,341 @@ HloInstruction* PadBaseShapeBeforeUnevenTiledSharding( return PadToShape(hlo, padded_base_shape, b); } +absl::optional PartialReplicateReshardCompatibleSharding( + const HloSharding& partial_sharding, const HloSharding& target_sharding) { + if (!partial_sharding.ReplicateOnLastTileDim()) { + return absl::nullopt; + } + int64 rank = partial_sharding.tile_assignment().num_dimensions() - 1; + int64 target_rank = target_sharding.tile_assignment().num_dimensions() - + (target_sharding.ReplicateOnLastTileDim() ? 1 : 0); + if (target_rank != rank) { + return absl::nullopt; + } + + absl::flat_hash_map device_to_replication_group; + partial_sharding.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + int64 gid = 0; + for (int64 i = 0; i < rank; ++i) { + gid *= partial_sharding.tile_assignment().dim(i); + gid += indices[i]; + } + device_to_replication_group[device] = gid; + }); + + // A dimension is expanded when target_tile_size > partial_tile_size and + // target_tile_size % partial_tile_size == 0. + // expand_tile_dims_positions is the index of the expand_dim. + std::vector expand_tile_dims_indices(rank, -1); + // expand_tile_size = target_tile_size / partial_tile_size. + std::vector expand_tile_sizes; + int num_expand_dims = 0; + for (int64 dim = 0; dim < rank; dim++) { + int64 partial_tile_size = partial_sharding.tile_assignment().dim(dim); + int64 target_tile_size = target_sharding.tile_assignment().dim(dim); + if (target_tile_size % partial_tile_size != 0 || + target_tile_size < partial_tile_size) { + return absl::nullopt; + } + + if (target_tile_size > partial_tile_size) { + expand_tile_dims_indices[dim] = num_expand_dims++; + expand_tile_sizes.emplace_back(target_tile_size / partial_tile_size); + } + } + + // Reshape the partial replicate tile_dimensions. + int64 num_target_replication = 1; + if (target_sharding.ReplicateOnLastTileDim()) { + num_target_replication = + target_sharding.tile_assignment().dimensions().back(); + } + auto reshape_dimensions = partial_sharding.tile_assignment().dimensions(); + int64 num_replication = reshape_dimensions.back(); + if (num_replication / num_target_replication != Product(expand_tile_sizes) || + num_replication % num_target_replication != 0) { + return absl::nullopt; + } + + reshape_dimensions.pop_back(); + reshape_dimensions.insert(reshape_dimensions.end(), expand_tile_sizes.begin(), + expand_tile_sizes.end()); + + if (target_sharding.ReplicateOnLastTileDim()) { + reshape_dimensions.push_back(num_target_replication); + } + + auto reshape_tile_assignment = partial_sharding.tile_assignment(); + reshape_tile_assignment.Reshape(reshape_dimensions); + + // Transpose. + std::vector perm; + perm.reserve(rank + expand_tile_sizes.size()); + for (int64 dim = 0; dim < rank; dim++) { + perm.emplace_back(dim); + if (expand_tile_dims_indices[dim] > -1) { + perm.emplace_back(expand_tile_dims_indices[dim] + rank); + } + } + auto transpose_sharding = hlo_sharding_util::TransposeSharding( + target_sharding.ReplicateOnLastTileDim() + ? 
HloSharding::PartialTile(reshape_tile_assignment) + : HloSharding::Tile(reshape_tile_assignment), + perm); + + // Reshape to target shape + auto transpose_tile_assignment = transpose_sharding.tile_assignment(); + transpose_tile_assignment.Reshape( + target_sharding.tile_assignment().dimensions()); + + bool groups_matching = true; + target_sharding.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + if (device_to_replication_group[device] != + device_to_replication_group[transpose_tile_assignment(indices)]) { + groups_matching = false; + } + }); + + if (groups_matching) { + return target_sharding; + } + return target_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(transpose_tile_assignment) + : HloSharding::Tile(transpose_tile_assignment); +} + +absl::optional TileToPartialReplicateHaloExchange( + HloInstruction* hlo, const Shape& base_shape, + const HloSharding& src_sharding, const HloSharding& dst_sharding, + const std::vector& replicate_dims, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) { + // Source is tile sharding. + auto padded_src_shape = + GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding); + // Target is partial replicate. + auto padded_dst_shape = + GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding); + if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) { + return hlo; + } + + auto partition_ordinals = + MakeTiledPartitionOrdinals(dst_sharding, partition_id, b); + + auto result = hlo; + auto hlo_shape = hlo->shape(); + for (auto dim : replicate_dims) { + int64 dst_shard_count = dst_sharding.tile_assignment().dim(dim); + int64 src_per_shard_size = + padded_src_shape.dimensions(dim) / dst_shard_count; + // Calculate per shard size using the sharding to compare if dst_sharding + // needs more padding at the end. + int64 dst_per_shard_size = + padded_dst_shape.dimensions(dim) / dst_shard_count; + + // If the src per-shard data has no redundant elements, nothing to do. + if (src_per_shard_size <= dst_per_shard_size || dst_shard_count == 1) { + continue; + } + + // If src_per_shard * replicate_factor > dst_per_shard, we need to + // re-distribute the data between the shards using collective permute. For + // example, if the dimension size is 6 and it is sharded 4 ways in the src + // but needs to be sharded 2 ways in the dst: 4-way sharding has 2 elements + // in each shard, while 2-way sharding has 3 elements, so the last element + // in the first shard would otherwise be sliced out and re-distribution is + // needed. + // + // 1. Calculate left_halo size. + // left-halo size is + // (src_per_shard_size - dst_per_shard_size) * i / replicate_factor + int64 replicate_factor = src_sharding.tile_assignment().dim(dim) / + dst_sharding.tile_assignment().dim(dim); + OffsetCalculation left_halo_size_function = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + src_per_shard_size - dst_per_shard_size, 0, replicate_factor)); + + // 2. Calculate right_halo size. + // right-halo size is 0 + OffsetCalculation right_halo_size_function = + OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1)); + + auto concat = result; + // 3. Halo exchange. + auto halo_exchange_result = ExchangeHalo( + result, left_halo_size_function, right_halo_size_function, dim, + src_sharding, collective_ops_creator, next_channel_id, b); + + if (halo_exchange_result.has_value()) { + concat = halo_exchange_result.value(); + } else { + return absl::nullopt; + } + + // 4. Slice the valid result.
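+    // (Assuming MultiplyAddDivideOffsetCalculation(m, a, d) evaluates to + // (m * i + a) / d, the constants below expand to + // (src_per_shard_size - dst_per_shard_size) * (dst_shard_count - 1 - i), + // which matches the offset formula in the next comment.)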
+    // Slice offset is + // (dst_shard_count - i - 1) * + // (src_per_shard_size - dst_per_shard_size) + // where i is the shard index under dst_sharding. + auto zero_s32 = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + OffsetCalculation start_offset_on_padded_concat_calculation = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + dst_per_shard_size - src_per_shard_size, + (src_per_shard_size - dst_per_shard_size) * (dst_shard_count - 1), + 1)); + auto slice_shape = concat->shape(); + slice_shape.set_dimensions(dim, + padded_src_shape.dimensions(dim) / + src_sharding.tile_assignment().dim(dim)); + std::vector slice_offsets(concat->shape().rank(), + zero_s32); + slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate( + partition_ordinals[dim], b); + result = b->AddInstruction(HloInstruction::CreateDynamicSlice( + slice_shape, concat, slice_offsets, slice_shape.dimensions())); + } + return result; +} + +absl::optional PadFromPartialReplicateShape( + HloInstruction* hlo, const Shape& base_shape, + const HloSharding& src_sharding, const HloSharding& dst_sharding, + const std::vector& expand_tile_dims, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) { + auto padded_src_shape = + GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding); + auto padded_dst_shape = + GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding); + if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) { + return hlo; + } + + auto partition_ordinals = + MakeTiledPartitionOrdinals(src_sharding, partition_id, b); + + HloInstruction* result = hlo; + auto zero = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + std::vector expand_dims_without_halo_exchange; + // Pad the dimensions that need halo exchange and record the padded dims + // that won't need halo exchange. + for (auto dim : expand_tile_dims) { + int64 src_shard_count = src_sharding.tile_assignment().dim(dim); + int64 src_per_shard_size = + padded_src_shape.dimensions(dim) / src_shard_count; + // Calculate per shard size using the sharding to compare if dst_sharding + // needs more padding at the end. + int64 dst_per_shard_size = + padded_dst_shape.dimensions(dim) / src_shard_count; + + // If dst_sharding doesn't need more padding at the end. + if (src_per_shard_size >= dst_per_shard_size) { + continue; + } + // If src sharding at this dimension is not partitioned, simply pad to + // the desired shape. + if (src_shard_count == 1) { + expand_dims_without_halo_exchange.emplace_back(dim); + continue; + } + + // If dst_sharding needs more padding at the end, we need to re-distribute + // the data between the shards using collective permute. + // For example, if the dimension size is 6 and it is sharded 2 ways in the + // src but needs to be sharded 4 ways in the dst: 4-way sharding needs 2 + // zeros of padding at the end and has 2 elements in each shard, while + // 2-way sharding has 3 elements in each shard, so re-distribution is + // needed. + // + // 1. Calculate left_halo size. + // left-halo size is 0 + OffsetCalculation left_halo_size_function = + OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1)); + + // 2. Calculate right_halo size.
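+    // (In the formula below, D denotes dst_per_shard_size and S denotes + // src_per_shard_size.)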
+ // right-halo size is D * (i + 1) - S * (i + 1) = (D - S) * i + (D - S) + OffsetCalculation right_halo_size_function = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + dst_per_shard_size - src_per_shard_size, + dst_per_shard_size - src_per_shard_size, 1)); + + auto concat = result; + // 3. Halo exchange. + auto halo_exchange_result = ExchangeHalo( + result, left_halo_size_function, right_halo_size_function, dim, + src_sharding, collective_ops_creator, next_channel_id, b); + + if (halo_exchange_result.has_value()) { + concat = halo_exchange_result.value(); + } else { + return absl::nullopt; + } + + // 4. Pad. + std::vector zero_padding(concat->shape().rank()); + PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding); + pad_config.mutable_dimensions(dim)->set_edge_padding_low(0); + int64 max_right_halo_size = + right_halo_size_function.MaxInRange(0, src_shard_count - 1); + pad_config.mutable_dimensions(dim)->set_edge_padding_high(std::max( + 0LL, padded_dst_shape.dimensions(dim) - + padded_src_shape.dimensions(dim) - max_right_halo_size)); + auto padded_concat_shape = ShapeInference::InferPadShape( + concat->shape(), zero->shape(), pad_config) + .ValueOrDie(); + concat = b->AddInstruction(HloInstruction::CreatePad( + padded_concat_shape, concat, zero, pad_config)); + + // 5. Slice the valid result. + // Slice offset is (D-S) * i + auto zero_s32 = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + OffsetCalculation start_offset_on_padded_concat_calculation = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + dst_per_shard_size - src_per_shard_size, 0, 1)); + auto slice_shape = concat->shape(); + slice_shape.set_dimensions(dim, dst_per_shard_size); + std::vector slice_offsets(concat->shape().rank(), + zero_s32); + slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate( + partition_ordinals[dim], b); + result = b->AddInstruction(HloInstruction::CreateDynamicSlice( + slice_shape, concat, slice_offsets, slice_shape.dimensions())); + } + + // Pad other dimensions that won't need halo exchange with a single pad. + if (!expand_dims_without_halo_exchange.empty()) { + std::vector zero_padding(result->shape().rank()); + PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding); + + auto padded_shape = result->shape(); + for (auto dim : expand_dims_without_halo_exchange) { + pad_config.mutable_dimensions(dim)->set_edge_padding_low(0); + pad_config.mutable_dimensions(dim)->set_edge_padding_high( + padded_dst_shape.dimensions(dim) - padded_src_shape.dimensions(dim)); + padded_shape.set_dimensions(dim, result->shape().dimensions(dim) + + padded_dst_shape.dimensions(dim) - + padded_src_shape.dimensions(dim)); + } + result = b->AddInstruction( + HloInstruction::CreatePad(padded_shape, result, zero, pad_config)); + } + + return result; +} + absl::optional UniqueTiledDim(const HloSharding& sharding) { if (sharding.IsTileMaximal()) { return absl::nullopt; } int64 dim = -1; - for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + int64 rank = sharding.ReplicateOnLastTileDim() + ? 
sharding.tile_assignment().num_dimensions() - 1 + : sharding.tile_assignment().num_dimensions(); + for (int64 i = 0; i < rank; ++i) { if (sharding.tile_assignment().dim(i) > 1) { if (dim != -1) { return absl::nullopt; @@ -925,7 +1280,8 @@ GetReshardAllToAllSourceTargetDims(const HloSharding& source, const HloSharding& target) { if (source.IsTileMaximal() || target.IsTileMaximal() || source.tile_assignment().num_dimensions() != - target.tile_assignment().num_dimensions()) { + target.tile_assignment().num_dimensions() || + source.NumTiles() != target.NumTiles()) { return absl::nullopt; } // Record partition count to index for indices that have different partition @@ -1010,61 +1366,112 @@ bool CanReshardWithCollectivePermute(const HloSharding& source, return !source.IsTileMaximal() && !target.IsTileMaximal() && source.tile_assignment().dimensions() == target.tile_assignment().dimensions() && + source.ReplicateOnLastTileDim() == target.ReplicateOnLastTileDim() && source.tile_assignment() != target.tile_assignment(); } GroupedSharding GroupShardingOnDims(const HloSharding& sharding, absl::Span group_dims) { + std::vector group_dim_shards(group_dims.size(), 1); + return GroupShardingOnDims(sharding, group_dims, group_dim_shards); +} + +GroupedSharding GroupShardingOnDims(const HloSharding& sharding, + absl::Span group_dims, + absl::Span group_dim_shards) { CHECK(!sharding.IsTileMaximal()); std::vector grouped_tiling_dims = sharding.tile_assignment().dimensions(); std::vector group_dim_sizes(group_dims.size()); for (int64 i = 0; i < group_dims.size(); ++i) { - group_dim_sizes[i] = grouped_tiling_dims[group_dims[i]]; - grouped_tiling_dims[group_dims[i]] = 1; + CHECK_EQ(grouped_tiling_dims[group_dims[i]] % group_dim_shards[i], 0); + group_dim_sizes[i] = + grouped_tiling_dims[group_dims[i]] / group_dim_shards[i]; + grouped_tiling_dims[group_dims[i]] = group_dim_shards[i]; } + std::vector> device_groups(Product(group_dim_sizes)); sharding.tile_assignment().Each( [&](absl::Span indices, int64 device) { int64 group_id = 0; - for (int64 dim : group_dims) { - group_id *= sharding.tile_assignment().dim(dim); - group_id += indices[dim]; + for (int64 i = 0; i < group_dims.size(); ++i) { + group_id *= sharding.tile_assignment().dim(group_dims[i]) / + group_dim_shards[i]; + group_id += indices[group_dims[i]] / group_dim_shards[i]; } device_groups[group_id].push_back(device); }); - Array grouped_tiling(grouped_tiling_dims); - grouped_tiling.FillIota(0); - return GroupedSharding( + auto grouped = GroupedSharding( std::move(device_groups), std::vector(group_dims.begin(), group_dims.end()), std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(), - HloSharding::Tile(grouped_tiling)); + HloSharding::Replicate()); + if (sharding.ReplicateOnLastTileDim()) { + grouped.data_rank--; + } + if (Product(grouped_tiling_dims) == 1 || + (sharding.ReplicateOnLastTileDim() && + Product(grouped_tiling_dims) == grouped_tiling_dims.back())) { + return grouped; + } + if (sharding.ReplicateOnLastTileDim() && grouped_tiling_dims.back() == 1) { + grouped_tiling_dims.pop_back(); + } + Array grouped_tiling(grouped_tiling_dims); + grouped_tiling.FillIota(0); + grouped.sharding = sharding.ReplicateOnLastTileDim() && + grouped_tiling_dims.size() == + sharding.tile_assignment().num_dimensions() + ? 
HloSharding::PartialTile(grouped_tiling) + : HloSharding::Tile(grouped_tiling); + return grouped; } HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) { - CHECK(!grouped_sharding.sharding.IsTileMaximal()); - std::vector tiling_dims = - grouped_sharding.sharding.tile_assignment().dimensions(); + std::vector tiling_dims; + bool partial_sharding = false; + auto grouped_tiling = grouped_sharding.sharding.tile_assignment(); + if (grouped_sharding.sharding.IsTileMaximal()) { + tiling_dims = std::vector(grouped_sharding.data_rank, 1); + if (grouped_sharding.device_groups[0].size() != 1) { + // This is partial sharding. + tiling_dims.push_back(grouped_sharding.device_groups[0].size()); + partial_sharding = true; + } + grouped_tiling = Array(tiling_dims); + grouped_tiling.FillIota(0); + } else { + partial_sharding = grouped_sharding.sharding.ReplicateOnLastTileDim(); + tiling_dims = grouped_sharding.sharding.tile_assignment().dimensions(); + if (absl::c_linear_search(grouped_sharding.group_dims, + tiling_dims.size())) { + tiling_dims.push_back(1); + grouped_tiling.Reshape(tiling_dims); + partial_sharding = true; + } + } for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) { - tiling_dims[grouped_sharding.group_dims[i]] = - grouped_sharding.group_dim_sizes[i]; + int64 dim = grouped_sharding.group_dims[i]; + tiling_dims[dim] *= grouped_sharding.group_dim_sizes[i]; } Array tiling(tiling_dims); - grouped_sharding.sharding.tile_assignment().Each( - [&](absl::Span indices, int64 device) { - std::vector ungrouped_inds(indices.begin(), indices.end()); - for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) { - int64 remaining_group_index = g; - for (int64 i = grouped_sharding.group_dims.size() - 1; i >= 0; --i) { - ungrouped_inds[grouped_sharding.group_dims[i]] = - remaining_group_index % grouped_sharding.group_dim_sizes[i]; - remaining_group_index /= grouped_sharding.group_dim_sizes[i]; - } - tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device]; - } - }); - return HloSharding::Tile(tiling); + grouped_tiling.Each([&](absl::Span indices, int64 device) { + std::vector ungrouped_inds(indices.begin(), indices.end()); + for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) { + int64 remaining_group_index = g; + for (int64 i = grouped_sharding.group_dims.size() - 1; i >= 0; --i) { + int64 dim = grouped_sharding.group_dims[i]; + int64 groups_in_this_dim = grouped_sharding.group_dim_sizes[i]; + ungrouped_inds[dim] = (remaining_group_index % groups_in_this_dim) * + grouped_tiling.dim(dim) + + indices[dim]; + remaining_group_index /= groups_in_this_dim; + } + tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device]; + } + }); + return partial_sharding ? HloSharding::PartialTile(tiling) + : HloSharding::Tile(tiling); } GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding, @@ -1118,12 +1525,15 @@ GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding, grouped_sharding.device_groups[g], reference.device_groups[ref_g]); } } - if (matching_groups) { + if (matching_groups && !grouped_sharding.sharding.IsTileMaximal()) { auto tiles = grouped_sharding.sharding.tile_assignment(); tiles.Each([&](absl::Span indices, int64* device) { *device = original_src_to_ref_permutation[*device]; }); - grouped_sharding.sharding = HloSharding::Tile(tiles); + grouped_sharding.sharding = + grouped_sharding.sharding.ReplicateOnLastTileDim() + ? 
HloSharding::PartialTile(tiles) + : HloSharding::Tile(tiles); } grouped_sharding.device_groups = std::move(reference.device_groups); return grouped_sharding; @@ -1134,6 +1544,9 @@ Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding, auto result = original_base_shape; for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) { int64 dim = grouped_sharding.group_dims[i]; + if (dim >= original_base_shape.rank()) { + continue; + } int64 groups = grouped_sharding.group_dim_sizes[i]; result.set_dimensions(dim, result.dimensions(dim) / groups); } @@ -1305,49 +1718,6 @@ HloInstruction* PerGroupSliceFromReplicated( shard_shape.dimensions())); } -absl::optional TransposeShardingWithCollapsedDims( - const HloSharding& source, absl::Span src_to_tgt, - absl::Span tgt_to_src) { - if (source.IsTileMaximal()) { - return source; - } - std::vector tgt_dims_skipping_new(tgt_to_src.size(), -1); - int64 skipped_tgt_dims = 0; - for (int64 i = 0; i < tgt_to_src.size(); ++i) { - if (tgt_to_src[i] < 0) { - skipped_tgt_dims++; - } else { - tgt_dims_skipping_new[i] = i - skipped_tgt_dims; - } - } - int64 skipped_src_dims = absl::c_count(src_to_tgt, -1); - std::vector perm(src_to_tgt.size()); - for (int64 i = 0; i < src_to_tgt.size(); ++i) { - if (src_to_tgt[i] < 0) { - if (source.tile_assignment().dim(i) > 1) { - return absl::nullopt; - } - perm[src_to_tgt.size() - skipped_src_dims] = i; - skipped_src_dims--; - } else { - perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i; - } - } - auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm); - if (skipped_tgt_dims == 0) { - return tgt_sharding; - } - auto reshape_tiles = tgt_sharding.tile_assignment(); - std::vector tgt_tiles(tgt_to_src.size(), 1); - for (int64 i = 0; i < tgt_tiles.size(); ++i) { - if (tgt_to_src[i] >= 0) { - tgt_tiles[i] = reshape_tiles.dim(tgt_dims_skipping_new[i]); - } - } - reshape_tiles.Reshape(tgt_tiles); - return HloSharding::Tile(reshape_tiles); -} - absl::optional ParseReductionComputation( const HloComputation* reduction_comp) { if (reduction_comp->num_parameters() != 2) { @@ -1366,5 +1736,47 @@ absl::optional ParseReductionComputation( return root->opcode(); } +absl::optional> FindMatchingPartitionedDimsForGrouping( + const HloSharding& sharding, + const std::vector>& device_groups) { + if (sharding.NumTiles() < device_groups.size() || device_groups.size() < 2 || + device_groups[0].size() < 2) { + return absl::nullopt; + } + int64 rank = sharding.tile_assignment().num_dimensions(); + if (sharding.ReplicateOnLastTileDim()) { + rank--; + } + absl::flat_hash_map> device_to_index; + sharding.tile_assignment().Each( + [&](absl::Span index, int64 device) { + device_to_index[device] = + std::vector(index.begin(), index.begin() + rank); + }); + std::vector dims; + int64 group_count = 1; + for (int64 i = 0; i < rank; ++i) { + if (device_to_index[device_groups[0][0]][i] == + device_to_index[device_groups[0][1]][i]) { + dims.push_back(i); + group_count *= sharding.tile_assignment().dim(i); + } + } + if (group_count != device_groups.size()) { + return absl::nullopt; + } + for (const auto& group : device_groups) { + for (int64 i = 1; i < group.size(); ++i) { + if (absl::c_any_of(dims, [&](const int64 dim) { + return device_to_index[group[i]][dim] != + device_to_index[group[0]][dim]; + })) { + return absl::nullopt; + } + } + } + return dims; +} + } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h index 
10b630e31ee..4fc193d9622 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h @@ -33,6 +33,10 @@ namespace spmd { // Returns true if the given sharding contains any replicated sharding. bool HasReplicatedSharding(const HloSharding& sharding); +// Creates constant value instructions of the given shape. The literal must be +// a scalar and is broadcast to the given shape. +HloInstruction* CreateConstant(const Shape& shape, Literal value, + SpmdBuilder* b); // Creates zero value instructions of the given shape. HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b); @@ -287,19 +291,25 @@ bool CanReshardWithCollectivePermute(const HloSharding& source, struct GroupedSharding { GroupedSharding(std::vector> device_groups, std::vector group_dims, - std::vector group_dim_sizes, int64 rank, + std::vector group_dim_sizes, int64 data_rank, HloSharding grouped_sharding) : device_groups(std::move(device_groups)), group_dims(std::move(group_dims)), group_dim_sizes(std::move(group_dim_sizes)), + data_rank(data_rank), sharding(std::move(grouped_sharding)) {} std::vector> device_groups; std::vector group_dims; std::vector group_dim_sizes; - int64 rank; + int64 data_rank; HloSharding sharding; }; +// Creates a GroupedSharding for a tiled sharding with group dim shard sizes. +GroupedSharding GroupShardingOnDims(const HloSharding& sharding, + absl::Span group_dims, + absl::Span group_dim_shards); + // Creates a GroupedSharding for a tiled sharding. GroupedSharding GroupShardingOnDims(const HloSharding& sharding, absl::Span group_dims); @@ -331,18 +341,50 @@ HloInstruction* PerGroupSliceFromReplicated( absl::Span group_dims, absl::Span group_dim_sizes, SpmdBuilder* b); -// Similar to hlo_sharding_util::TransposeSharding(), but allows removing/adding -// non-partitioned dimensions. In src_to_tgt and tgt_to_src, -1 represents a -// non-existing dimension. -absl::optional TransposeShardingWithCollapsedDims( - const HloSharding& source, absl::Span src_to_tgt, - absl::Span tgt_to_src); - // Returns the opcode if `reduction_comp` represents a simple binary elementwise // computation on the two operands. absl::optional ParseReductionComputation( const HloComputation* reduction_comp); +// Pads hlo from its partial replicate shape to the shape required by +// `dst_sharding`. If dst_sharding needs more padding and the per-shard size +// increases under dst_sharding, a halo exchange on the right side is needed. +absl::optional PadFromPartialReplicateShape( + HloInstruction* hlo, const Shape& base_shape, + const HloSharding& src_sharding, const HloSharding& dst_sharding, + const std::vector& expand_tile_dims, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b); + +// Gets a compatible sharding for resharding from a partial replicate sharding +// to a desired target tiled sharding. +// Compatible means the partially replicated sharding can be transformed into +// the target tile dimensions by a dynamic slice. +// For example, if partial_sharding is +// {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +// and the target sharding is {devices=[2,2]0,1,2,3}, the returned compatible +// sharding will be sharding={devices=[2,2]0,2,1,3}. +// If the input sharding is not partially replicated or can't be resharded to +// target_tile_dims by a dynamic slice, returns absl::nullopt. +// If target_sharding is already compatible, returns it.
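+// A minimal usage sketch (illustrative only; variable names are hypothetical): +//   auto compatible = +//       PartialReplicateReshardCompatibleSharding(partial_sharding, target); +//   if (compatible.has_value()) { +//     // Dynamic-slice the partially replicated data according to *compatible; +//     // *compatible may still differ from `target` in device order. +//   }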
+absl::optional PartialReplicateReshardCompatibleSharding( + const HloSharding& partial_sharding, const HloSharding& target_sharding); + +// Do left halo exchange if all-reduce directly from tile sharding to partial +// replicate sharding will remove useful data from the source. +absl::optional TileToPartialReplicateHaloExchange( + HloInstruction* hlo, const Shape& base_shape, + const HloSharding& src_sharding, const HloSharding& dst_sharding, + const std::vector& replicate_dims, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b); + +// Finds a list of dimensions that can be grouped on such that it will have the +// specified device groups. Group order and dimension order are ignored. +absl::optional> FindMatchingPartitionedDimsForGrouping( + const HloSharding& sharding, + const std::vector>& device_groups); + } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.cc b/tensorflow/compiler/xla/service/triangular_solve_expander.cc index d54eb9e78c3..4015c69e3e2 100644 --- a/tensorflow/compiler/xla/service/triangular_solve_expander.cc +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.cc @@ -89,16 +89,23 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { // The last block might be smaller than the block size, // so we will need to pad it if (n % block_size != 0) { - // Pad with zeros + // Pad with identity matrix. auto last_blocks = SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n}); PaddingConfig config = MakeNoPaddingConfig(ndims); int64 padding = block_size - n % block_size; - config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding); config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding); last_blocks = Pad(last_blocks, Zero(builder, shape.element_type()), config); + auto eye = + IdentityMatrix(builder, shape.element_type(), padding, padding); + config = MakeNoPaddingConfig(ndims); + config.mutable_dimensions(ndims - 2)->set_edge_padding_low(n % + block_size); + eye = Pad(eye, Zero(builder, shape.element_type()), config); + last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1); + // Add a singleton dimension // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size] TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks)); @@ -121,134 +128,6 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { }); } -XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, - bool conjugate_a, - PrecisionConfig::Precision precision) { - XlaBuilder* builder = diag_blocks.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - // Input is a batch of square lower triangular square matrices. Its shape is - // (..., size, size). We resize this to (num_blocks, size, size). 
- TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks)); - int64 block_size = ShapeUtil::GetDimension(shape, -1); - int64 num_blocks = ShapeUtil::ElementsIn(shape) / - tensorflow::MathUtil::IPow(block_size, 2); - diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); - - // The input must be triangular because we rely on that when doing - // multiplications later on - diag_blocks = Triangle(diag_blocks, /*lower=*/lower); - - // Rescale blocks to be unit triangular, but avoid dividing by - // zero (which can happen if the last block was padded) otherwise it will - // introduce nans which will propagate - auto diags = GetMatrixDiagonal(diag_blocks); - auto ones = FullLike(diags, 1); - diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); - auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); - - // We can now use the fact that for an upper triangular matrix - // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have - // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks - // have been rescaled to be unit triangular, so L22 = L22' = 1. - - // Initialize the output matrix with -1s on the diagonal. We use -1 instead - // of 1 because we cannot do matrix-vector multiplies with variable shapes - // inside of a loop, or do irregularly shaped in-place updates. Hence, - // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the - // entire row i.e. we calculate - // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I]) - // which means [L21 L22 0] <- [-L21 * L11', L22, 0]. - auto identity = - IdentityMatrix(builder, shape.element_type(), block_size, block_size); - auto neg_identity = -identity; - - // The first or last diagonal element should be set to 1 instead of -1 - // though, since we never update it - auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); - auto start_index = ConstantR0(builder, (lower) ? 0 : block_size - 1); - auto output_block = - DynamicUpdateSlice(neg_identity, pos_one, - /*start_indices=*/{start_index, start_index}); - - // Broadcast diag([1, -1, -1, ...]) to every block - XlaOp output = Broadcast(output_block, - /*broadcast_sizes=*/{num_blocks}); - - // Now we construct a loop that performs matrix-vector multiplications - // inverting the blocks one row at a time - std::vector tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - ShapeUtil::MakeShape(S32, {}), - // The output has the shape of A, with one row updated each iteration. - ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size}), - // The input is a loop invariant. - ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size})}; - Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes); - - auto init_i = One(builder, S32); - auto init = Tuple(builder, {init_i, output, scaled_diag_blocks}); - - // Construct the loop condition function. - std::unique_ptr condb = - builder->CreateSubBuilder("InvertDiagCond"); - { - auto i = GetTupleElement( - Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); - Lt(i, ConstantR0(condb.get(), block_size)); - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function. 
- std::unique_ptr bodyb = - builder->CreateSubBuilder("InvertDiagBody"); - { - auto input_tuple = - Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple"); - - auto i = GetTupleElement(input_tuple, 0); - auto body_out = GetTupleElement(input_tuple, 1); - auto body_input = GetTupleElement(input_tuple, 2); - - auto zero = ConstantR0(bodyb.get(), 0); - auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i; - auto input_row = - DynamicSlice(body_input, {zero, j, zero}, - /*slice_sizes=*/{num_blocks, 1, block_size}); - - // We want -L21 L11^{-1} - DotDimensionNumbers dnums; - dnums.add_lhs_batch_dimensions(0); - dnums.add_rhs_batch_dimensions(0); - dnums.add_lhs_contracting_dimensions(2); - dnums.add_rhs_contracting_dimensions(1); - PrecisionConfig precision_proto; - precision_proto.add_operand_precision(precision); - precision_proto.add_operand_precision(precision); - auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); - - body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero}); - - auto next_i = i + ScalarLike(i, 1); - Tuple(bodyb.get(), {next_i, body_out, body_input}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto invert_while = While(cond, body, init); - auto inv_diag_blocks = GetTupleElement(invert_while, 1); - - // Undo the scaling - inv_diag_blocks = Div(inv_diag_blocks, diags, - /*broadcast_dimensions=*/{0, 1}); - - // Reshape back to original batch major dimensions - return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions())); - }); -} - XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, bool left_side, bool lower, bool transpose_a, bool conjugate_a, @@ -357,10 +236,140 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, }); } -XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, - bool transpose_a, bool conjugate_a, - bool unit_diagonal, int64 block_size, - PrecisionConfig::Precision precision) { +} // namespace + +XlaOp TriangularSolveExpander::InvertDiagonalBlocks( + XlaOp diag_blocks, bool lower_triangular, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = diag_blocks.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + // Input is a batch of square lower triangular square matrices. Its shape is + // (..., size, size). We resize this to (num_blocks, size, size). + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks)); + int64 block_size = ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = ShapeUtil::ElementsIn(shape) / + tensorflow::MathUtil::IPow(block_size, 2); + diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); + + // The input must be triangular because we rely on that when doing + // multiplications later on + diag_blocks = Triangle(diag_blocks, /*lower=*/lower_triangular); + + // Rescale blocks to be unit triangular, but avoid dividing by + // zero (which can happen if the last block was padded) otherwise it will + // introduce nans which will propagate + auto diags = GetMatrixDiagonal(diag_blocks); + auto ones = FullLike(diags, 1); + diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); + auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); + + // We can now use the fact that for an upper triangular matrix + // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have + // L22' = -L22' * L21 * L11'. 
In our case, L21 is a vector and our blocks + // have been rescaled to be unit triangular, so L22 = L22' = 1. + + // Initialize the output matrix with -1s on the diagonal. We use -1 instead + // of 1 because we cannot do matrix-vector multiplies with variable shapes + // inside of a loop, or do irregularly shaped in-place updates. Hence, + // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the + // entire row i.e. we calculate + // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I]) + // which means [L21 L22 0] <- [-L21 * L11', L22, 0]. + auto identity = + IdentityMatrix(builder, shape.element_type(), block_size, block_size); + auto neg_identity = -identity; + + // The first or last diagonal element should be set to 1 instead of -1 + // though, since we never update it + auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); + auto start_index = + ConstantR0(builder, lower_triangular ? 0 : block_size - 1); + auto output_block = + DynamicUpdateSlice(neg_identity, pos_one, + /*start_indices=*/{start_index, start_index}); + + // Broadcast diag([1, -1, -1, ...]) to every block + XlaOp output = Broadcast(output_block, + /*broadcast_sizes=*/{num_blocks}); + + // Now we construct a loop that performs matrix-vector multiplications + // inverting the blocks one row at a time + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + ShapeUtil::MakeShape(S32, {}), + // The output has the shape of A, with one row updated each iteration. + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size}), + // The input is a loop invariant. + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size})}; + Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes); + + auto init_i = One(builder, S32); + auto init = Tuple(builder, {init_i, output, scaled_diag_blocks}); + + // Construct the loop condition function. + std::unique_ptr condb = + builder->CreateSubBuilder("InvertDiagCond"); + { + auto i = GetTupleElement( + Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); + Lt(i, ConstantR0(condb.get(), block_size)); + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function. + std::unique_ptr bodyb = + builder->CreateSubBuilder("InvertDiagBody"); + { + auto input_tuple = + Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple"); + + auto i = GetTupleElement(input_tuple, 0); + auto body_out = GetTupleElement(input_tuple, 1); + auto body_input = GetTupleElement(input_tuple, 2); + + auto zero = ConstantR0(bodyb.get(), 0); + auto j = lower_triangular ? 
i : ScalarLike(i, block_size - 1) - i; + auto input_row = + DynamicSlice(body_input, {zero, j, zero}, + /*slice_sizes=*/{num_blocks, 1, block_size}); + + // We want -L21 L11^{-1} + DotDimensionNumbers dnums; + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + PrecisionConfig precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); + + body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero}); + + auto next_i = i + ScalarLike(i, 1); + Tuple(bodyb.get(), {next_i, body_out, body_input}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto invert_while = While(cond, body, init); + auto inv_diag_blocks = GetTupleElement(invert_while, 1); + // Undo the scaling + inv_diag_blocks = Div(inv_diag_blocks, diags, + /*broadcast_dimensions=*/{0, 1}); + + // Reshape back to original batch major dimensions + return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions())); + }); +} + +XlaOp TriangularSolveExpander::BuildTriangularSolve( + XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a, + bool conjugate_a, bool unit_diagonal, int64 block_size, + PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); @@ -422,6 +431,11 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, return b; } + // Degenerate case: 1x1 matrices. + if (ShapeUtil::GetDimension(a_shape, -1) == 1) { + return unit_diagonal ? b : Div(b, MaybeConjugate(a, conjugate_a)); + } + // TODO(phawkins): consider pushing triangle masking into // InvertDiagonalBlocks. if (unit_diagonal) { @@ -440,8 +454,7 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, auto diag_blocks = DiagonalBlocks(a, block_size); // We invert these blocks in parallel using batched matrix-vector products - auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a, - conjugate_a, precision); + auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, precision); // We now find the solution using GEMMs auto x = @@ -452,8 +465,6 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, }); } -} // namespace - TriangularSolveExpander::TriangularSolveExpander(int64 block_size) : block_size_(block_size) {} diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.h b/tensorflow/compiler/xla/service/triangular_solve_expander.h index 362e8557229..3f9e58a3246 100644 --- a/tensorflow/compiler/xla/service/triangular_solve_expander.h +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.h @@ -17,6 +17,7 @@ limitations under the License. 
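The comments above describe the blocked algorithm in terms of unit-triangular blocks. Below is a minimal standalone sketch (written for illustration here, not taken from the XLA sources) of the same row-by-row idea for a single unit lower-triangular matrix: once rows 0..i-1 of the inverse are known, row i is obtained by negating the product of row i of the input with the partially built inverse, which is what the While loop above does per block with DynamicUpdateSlice.

#include <cstdio>

// Illustrative sketch only: invert a unit lower-triangular n x n matrix by
// forward substitution, one row at a time.  Row i of the inverse depends only
// on rows 0..i-1, which have already been computed.
int main() {
  const int n = 3;
  // L is unit lower triangular (1s on the diagonal).
  double L[n][n] = {{1, 0, 0}, {2, 1, 0}, {3, 4, 1}};
  double inv[n][n] = {};  // Will hold L^{-1}, also unit lower triangular.
  for (int i = 0; i < n; ++i) {
    inv[i][i] = 1.0;
    for (int j = 0; j < i; ++j) {
      // inv[i][j] = -sum_{k=j..i-1} L[i][k] * inv[k][j], using L[i][i] == 1.
      double acc = 0.0;
      for (int k = j; k < i; ++k) acc += L[i][k] * inv[k][j];
      inv[i][j] = -acc;
    }
  }
  for (int i = 0; i < n; ++i) {
    printf("%6.2f %6.2f %6.2f\n", inv[i][0], inv[i][1], inv[i][2]);
  }
  return 0;
}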
#define TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ #include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/op_expander_pass.h" namespace xla { @@ -35,6 +36,14 @@ class TriangularSolveExpander : public OpExpanderPass { StatusOr ExpandInstruction( HloInstruction* instruction) override; + virtual XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower_triangular, + PrecisionConfig::Precision precision); + + XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + bool unit_diagonal, int64 block_size, + PrecisionConfig::Precision precision); + private: // Block size for BuildTriangularSolve const int64 block_size_; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index c66f9d96a50..e2b977ad493 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -333,10 +333,10 @@ TEST_F(TuplePointsToAnalysisTest, CopyStartAndCopyDone) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary( + auto copy_start = builder.AddInstruction(HloInstruction::CreateCopyStart( ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(), ShapeUtil::MakeShape(U32, {})}), - HloOpcode::kCopyStart, constant)); + constant)); auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopyDone, copy_start)); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index c80123bcd50..785fdecbfa0 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -37,23 +37,15 @@ namespace m = match; using absl::optional; using hlo_query::ContainsInstrWithOpcode; -// Tries to remove elements in a while loop's tuple that aren't used within the -// loop. -// -// Specifically, if a loop is tuple-shaped, and there exists some element of -// that tuple that is not used by the loop condition and is not used by the loop -// body except to pass it to the next iteration of the loop, then we can remove -// that element from the loop's tuples. -static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - - // Don't try this transformation if the while loop isn't removable, since if - // it succeeds ultimately we're going to have to replace the old while loop - // with a new one. - if (!while_op->parent()->IsSafelyRemovable(while_op)) { - VLOG(2) << "Can't remove dead parameters from non-removable while op."; - return false; - } +// This is a utility function that removes the given tuple indices from the +// while loop init, body, and condition. The final shape returned is still the +// same as before. +static StatusOr RemoveDeadTupleIndices( + HloInstruction* while_op, absl::flat_hash_set& used_tuple_indices) { + // Build up maps from the old/new to the new/old tuple indices. 
+ std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), + used_tuple_indices.end()); + absl::c_sort(new_to_old_tuple_idx); HloModule* module = while_op->GetModule(); HloComputation* computation = while_op->parent(); @@ -62,107 +54,8 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { HloComputation* while_body = while_op->while_body(); HloInstruction* while_body_root = while_body->root_instruction(); - if (!while_init->shape().IsTuple()) { - VLOG(2) << "While op's carried value isn't tuple shaped."; - return false; - } - - if (while_body_root->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While body's root is not a tuple(...) instruction."; - return false; - } - auto print_no_metadata = HloPrintOptions().set_print_metadata(false); - // Bail if param0 of while_cond or while_body has users which aren't of type - // get-tuple-element. - for (const HloInstruction* instr : {while_body->parameter_instruction(0), - while_cond->parameter_instruction(0)}) { - for (const HloInstruction* user : instr->users()) { - if (user->opcode() != HloOpcode::kGetTupleElement) { - VLOG(2) << "Cowardly refusing to analyze while loop with " - << instr->ToString(print_no_metadata) - << " used by non-GTE instruction " - << user->ToString(print_no_metadata) << " in computation " - << instr->parent()->name(); - return false; - } - } - } - - const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); - if (tuple_size == 0) { - VLOG(2) << "Can't remove elements from while loop's tuple -- it's already " - "empty."; - return false; - } - - absl::flat_hash_set used_tuple_indices; - for (HloComputation* comp : {while_body, while_cond}) { - // The HLO verifier ensures that while_input's shape matches while_init's - // shape, which we verified above is a tuple. - HloInstruction* while_input = comp->parameter_instruction(0); - - for (const HloInstruction* user : while_input->users()) { - // This user doesn't count if it's only used by the while body's root, and - // the root places the tuple element into the same index of the tuple as - // it came from. That just amounts to us carrying the variable through - // the loop. - // - // Careful: HloInstruction::operand_index returns the first index the - // operand appears in, but it may appear more than once! - if (user->user_count() == 1 && user->users().front() == while_body_root && - while_body_root->operand_index(user) == user->tuple_index() && - absl::c_count(while_body_root->operands(), user) == 1) { - continue; - } - - used_tuple_indices.insert(user->tuple_index()); - if (used_tuple_indices.size() == tuple_size) { - VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) - << " uses all of its inputs; no simplification possible."; - return false; - } - } - } - - // If a tuple element is not passed unmodified from the while body's param0 - // through to the while body's root, count that element as "used", since - // removing that element would be observable. 
- for (int64 i = 0; i < while_body_root->operand_count(); ++i) { - if (used_tuple_indices.contains(i)) { - continue; - } - - auto* operand = while_body_root->operand(i); - if (operand->opcode() != HloOpcode::kGetTupleElement || - operand->operand(0) != while_body->parameter_instruction(0) || - operand->tuple_index() != i) { - VLOG(2) << "Tuple index " << i - << " is not passed through loop body unmodified."; - used_tuple_indices.insert(i); - - if (used_tuple_indices.size() == tuple_size) { - VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) - << " uses all of its inputs; no simplification possible."; - return false; - } - } - } - - // If we got here, used_tuple_indices.size() < tuple_size, meaning some - // elements of the loop's tuple aren't used by while_body or while_cond. - CHECK_LT(used_tuple_indices.size(), tuple_size); - - VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size() - << " elements from tuple of " - << while_op->ToString(print_no_metadata); - - // Build up maps from the old/new to the new/old tuple indices. - std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), - used_tuple_indices.end()); - absl::c_sort(new_to_old_tuple_idx); - absl::flat_hash_map old_to_new_tuple_idx; for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { int64 old_idx = new_to_old_tuple_idx[new_idx]; @@ -288,6 +181,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // The tuple simplifier will then simplify this if possible, removing // new_tuple and while_init. std::vector new_tuple_elems; + const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) { auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx); if (new_tuple_idx_it != old_to_new_tuple_idx.end()) { @@ -305,9 +199,293 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { HloInstruction* new_tuple = computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems)); TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple)); + + return new_while_op; +} + +// Tries to remove elements in a while loop's tuple that aren't used within the +// loop. +// +// Specifically, if a loop is tuple-shaped, and there exists some element of +// that tuple that is not used by the loop condition and is not used by the loop +// body except to pass it to the next iteration of the loop, then we can remove +// that element from the loop's tuples. +static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + // Don't try this transformation if the while loop isn't removable, since if + // it succeeds ultimately we're going to have to replace the old while loop + // with a new one. + if (!while_op->parent()->IsSafelyRemovable(while_op)) { + VLOG(2) << "Can't remove dead parameters from non-removable while op."; + return false; + } + + HloInstruction* while_init = while_op->mutable_operand(0); + HloComputation* while_cond = while_op->while_condition(); + HloComputation* while_body = while_op->while_body(); + HloInstruction* while_body_root = while_body->root_instruction(); + + if (!while_init->shape().IsTuple()) { + VLOG(2) << "While op's carried value isn't tuple shaped."; + return false; + } + + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple(...) 
instruction."; + return false; + } + + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); + + // Bail if param0 of while_cond or while_body has users which aren't of type + // get-tuple-element. + for (const HloInstruction* instr : {while_body->parameter_instruction(0), + while_cond->parameter_instruction(0)}) { + for (const HloInstruction* user : instr->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + VLOG(2) << "Cowardly refusing to analyze while loop with " + << instr->ToString(print_no_metadata) + << " used by non-GTE instruction " + << user->ToString(print_no_metadata) << " in computation " + << instr->parent()->name(); + return false; + } + } + } + + const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); + if (tuple_size == 0) { + VLOG(2) << "Can't remove elements from while loop's tuple -- it's already " + "empty."; + return false; + } + + absl::flat_hash_set used_tuple_indices; + for (HloComputation* comp : {while_body, while_cond}) { + // The HLO verifier ensures that while_input's shape matches while_init's + // shape, which we verified above is a tuple. + HloInstruction* while_input = comp->parameter_instruction(0); + + for (const HloInstruction* user : while_input->users()) { + // This user doesn't count if it's only used by the while body's root, and + // the root places the tuple element into the same index of the tuple as + // it came from. That just amounts to us carrying the variable through + // the loop. + // + // Careful: HloInstruction::operand_index returns the first index the + // operand appears in, but it may appear more than once! + if (user->user_count() == 1 && user->users().front() == while_body_root && + while_body_root->operand_index(user) == user->tuple_index() && + absl::c_count(while_body_root->operands(), user) == 1) { + continue; + } + + used_tuple_indices.insert(user->tuple_index()); + if (used_tuple_indices.size() == tuple_size) { + VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) + << " uses all of its inputs; no simplification possible."; + return false; + } + } + } + + // If a tuple element is not passed unmodified from the while body's param0 + // through to the while body's root, count that element as "used", since + // removing that element would be observable. + for (int64 i = 0; i < while_body_root->operand_count(); ++i) { + if (used_tuple_indices.contains(i)) { + continue; + } + + auto* operand = while_body_root->operand(i); + if (operand->opcode() != HloOpcode::kGetTupleElement || + operand->operand(0) != while_body->parameter_instruction(0) || + operand->tuple_index() != i) { + VLOG(2) << "Tuple index " << i + << " is not passed through loop body unmodified."; + used_tuple_indices.insert(i); + + if (used_tuple_indices.size() == tuple_size) { + VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) + << " uses all of its inputs; no simplification possible."; + return false; + } + } + } + + // If we got here, used_tuple_indices.size() < tuple_size, meaning some + // elements of the loop's tuple aren't used by while_body or while_cond. + CHECK_LT(used_tuple_indices.size(), tuple_size); + + VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size() + << " elements from tuple of " + << while_op->ToString(print_no_metadata); + + TF_ASSIGN_OR_RETURN(while_op, + RemoveDeadTupleIndices(while_op, used_tuple_indices)); + return true; } +// This is a helper function for TryRemoveRepeatedWhileTupleIndices. 
It removes +// duplicates by replacing them with tuple_index, followed by a call to +// RemoveDeadTupleIndices. +static StatusOr TryRemoveRepeatedWhileTupleIndicesHelper( + HloInstruction* while_op, const int64 tuple_index, + absl::flat_hash_set& duplicates) { + HloComputation* while_cond = while_op->while_condition(); + HloComputation* while_body = while_op->while_body(); + HloInstruction* while_init = while_op->mutable_operand(0); + + VLOG(2) << "while_init " << while_init->ToString() << " operands " + << while_init->operand_count(); + VLOG(2) << "while_body_root " << while_body->root_instruction()->ToString() + << " operands " << while_body->root_instruction()->operand_count(); + + // Change the loop body and condition such that uses of the duplicates are + // replaced with the original tuple element. + for (HloComputation* comp : {while_body, while_cond}) { + auto new_get = comp->AddInstruction(HloInstruction::CreateGetTupleElement( + comp->parameter_instruction(0)->shape().tuple_shapes(tuple_index), + comp->parameter_instruction(0), tuple_index)); + + std::vector instrs_to_replace; + for (auto* instr : comp->instructions()) { + if (instr->opcode() == HloOpcode::kGetTupleElement && + duplicates.contains(instr->tuple_index()) && + instr->operand(0) == comp->parameter_instruction(0)) { + instrs_to_replace.push_back(instr); + } + } + + for (auto instr : instrs_to_replace) { + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_get)); + } + } + + // We know which tuple indices are useful; i.e, those which aren't duplicates. + absl::flat_hash_set used_tuple_indices; + for (int index = 0; index < while_init->shape().tuple_shapes_size(); + ++index) { + if (!duplicates.count(index)) { + used_tuple_indices.insert(index); + } + } + + // Remove the duplicate tuple elements. + TF_ASSIGN_OR_RETURN(while_op, + RemoveDeadTupleIndices(while_op, used_tuple_indices)); + + return while_op; +} + +// If the while loop init passes the same values to several tuple indices, and +// if the body keeps on passing them through, we can remove the duplicates. +static StatusOr TryRemoveRepeatedWhileTupleIndices( + HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + int index_to_investigate = 0; + // Don't try this transformation if the while loop isn't removable, since if + // it succeeds ultimately we're going to have to replace the old while loop + // with a new one. + if (!while_op->parent()->IsSafelyRemovable(while_op)) { + VLOG(2) << "Can't remove dead parameters from non-removable while op."; + return false; + } + + HloInstruction* while_init = while_op->mutable_operand(0); + HloComputation* while_cond = while_op->while_condition(); + HloComputation* while_body = while_op->while_body(); + HloInstruction* while_body_root = while_body->root_instruction(); + + if (!while_init->shape().IsTuple()) { + VLOG(2) << "While op's carried value isn't tuple shaped."; + return false; + } + + bool changed = false; + while (index_to_investigate < while_init->shape().tuple_shapes_size()) { + if (!while_init->shape().IsTuple() || + while_init->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While op's carried value isn't tuple shaped."; + return false; + } + + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple(...) 
instruction."; + return false; + } + + auto& while_shape = while_init->shape(); + VLOG(2) << "Iterating " << index_to_investigate; + + absl::flat_hash_set duplicates; + auto* pivot_init_elem = while_init->operand(index_to_investigate); + auto* pivot_body_elem = while_body_root->operand(index_to_investigate); + if (pivot_body_elem->opcode() == HloOpcode::kGetTupleElement && + pivot_body_elem->operand(0) == while_body->parameter_instruction(0)) { + if (pivot_body_elem->tuple_index() != index_to_investigate) { + VLOG(2) << "Mismatch between pivot_body_elem->tuple_index() " + << pivot_body_elem->tuple_index() << " index_to_investigate " + << index_to_investigate; + index_to_investigate++; + continue; + } + } else { + index_to_investigate++; + continue; + } + + // Look from index_to_investigate onwards to see if it is repeated. + for (int64 i = index_to_investigate + 1; + i < while_shape.tuple_shapes_size(); ++i) { + auto* init_elem = while_init->operand(i); + auto* body_elem = while_body_root->operand(i); + if (body_elem->opcode() == HloOpcode::kGetTupleElement && + body_elem->operand(0) == while_body->parameter_instruction(0)) { + if (body_elem->tuple_index() != i) { + VLOG(2) << "Mismatch between body_elem->tuple_index() " + << body_elem->tuple_index() << " i " << i; + continue; + } + } else { + continue; + } + + if (pivot_init_elem == init_elem) { + VLOG(2) << "init_elem " << init_elem->ToString() << " pivot_init_elem " + << pivot_init_elem->ToString(); + VLOG(2) << "body_elem " << body_elem->ToString() << " pivot_body_elem " + << pivot_body_elem->ToString(); + duplicates.insert(i); + } + } + + // If duplicates are found, call the helper to remove them. + if (!duplicates.empty()) { + VLOG(2) << "Duplicate found " << duplicates.size() << " pivot_init " + << pivot_init_elem->ToString(); + TF_ASSIGN_OR_RETURN(while_op, + TryRemoveRepeatedWhileTupleIndicesHelper( + while_op, index_to_investigate, duplicates)); + changed = true; + VLOG(2) << "Changed while_op " << while_op->ToString() + << " while_op operand count " << while_op->operand_count(); + // Update the while loop variables so we can continue looking for + // duplicates of a different index. + while_init = while_op->mutable_operand(0); + while_cond = while_op->while_condition(); + while_body = while_op->while_body(); + while_body_root = while_body->root_instruction(); + } + index_to_investigate++; + } + + return changed; +} + // Removes each loop parameter (i.e. member of the while loop tuple) that is a // constant and is the same in the while loop body and the while loop init. static StatusOr TryRemoveConstantParams(HloInstruction* while_op) { @@ -1048,6 +1226,7 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op)); changed |= result; + if (result) { // Don't continue simplifying after successfully removing the while loop // -- that would result in use-after-free nastiness. @@ -1067,6 +1246,12 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { // successful, meaning that `while_op` is no longer valid after one of these // transformations returns true. 
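For readers skimming the pass above, here is a small self-contained sketch (illustrative only, not code from this patch) of the index bookkeeping that RemoveDeadTupleIndices relies on: the surviving tuple indices are sorted into a new_to_old vector, its inverse old_to_new map is used to rewrite get-tuple-element indices, and any old index missing from the map is dropped when the loop tuple is rebuilt.

#include <algorithm>
#include <cstdio>
#include <map>
#include <set>
#include <vector>

int main() {
  const int tuple_size = 5;
  // Suppose the analysis decided only indices {0, 2, 4} of the loop-carried
  // tuple are really used (illustrative numbers, not from the patch).
  std::set<int> used_tuple_indices = {0, 2, 4};

  // new_to_old: position in the shrunken tuple -> position in the old tuple.
  std::vector<int> new_to_old(used_tuple_indices.begin(),
                              used_tuple_indices.end());
  std::sort(new_to_old.begin(), new_to_old.end());

  // old_to_new: inverse map, used to rewrite get-tuple-element indices.
  std::map<int, int> old_to_new;
  for (int new_idx = 0; new_idx < static_cast<int>(new_to_old.size());
       ++new_idx) {
    old_to_new[new_to_old[new_idx]] = new_idx;
  }

  for (int old_idx = 0; old_idx < tuple_size; ++old_idx) {
    auto it = old_to_new.find(old_idx);
    if (it != old_to_new.end()) {
      printf("old tuple index %d -> new tuple index %d\n", old_idx, it->second);
    } else {
      printf("old tuple index %d is dropped\n", old_idx);
    }
  }
  return 0;
}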
+ TF_ASSIGN_OR_RETURN(result, TryRemoveRepeatedWhileTupleIndices(while_op)); + changed |= result; + if (result) { + continue; + } + TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op)); changed |= result; if (result) { @@ -1074,6 +1259,7 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { } TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op)); + changed |= result; if (result) { continue; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index d715fb3857a..c93cb5dc347 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -794,5 +794,51 @@ TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_SkipS16) { .ValueOrDie()); } +TEST_F(WhileLoopSimplifierTest, RemoveRepeatedParams) { + const string hlo_string = R"( + HloModule SwappingTupleElements + + SwappingTupleElements.body { + loop_var = (s32[], s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element(loop_var), index=0 + get-tuple-element.1 = s32[] get-tuple-element(loop_var), index=1 + get-tuple-element.2 = s32[] get-tuple-element(loop_var), index=2 + y = s32[] add(get-tuple-element.1, get-tuple-element.2) + ROOT tuple = (s32[], s32[], s32[]) tuple(s32[] get-tuple-element, y, + s32[] get-tuple-element.2) + } + + SwappingTupleElements.always_true { + param = (s32[], s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element(param), index=0 + get-tuple-element.1 = s32[] get-tuple-element(param), index=1 + ROOT less-than = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT + } + + ENTRY SwappingTupleElements { + x = s32[] parameter(0) + y = s32[] parameter(1) + tuple.1 = (s32[], s32[], s32[]) tuple(s32[] x, s32[] y, s32[] x) + ROOT while = (s32[], s32[], s32[]) while(tuple.1), + condition=SwappingTupleElements.always_true, + body=SwappingTupleElements.body + } + )"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + HloInstruction* new_while = FindFirstWhile(m.get()); + Shape new_while_shape = ParseShape("(s32[], s32[])").ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + new_while_shape)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index b4982f1d8e4..64c9635f335 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -61,6 +61,10 @@ class ShapeLayout { // Returns the shape (with layouts). const Shape& shape() const { return shape_; } + // Clear dynamic dimensions of this module. Pretending the module creates + // static results. Useful in inspecting full outputs when testing. + void ClearDynamicShape() { shape_.clear_dynamic_dimensions(); } + // Checks that a layout is set for the shape, and returns a reference to the // layout directly on the shape. Shape must not be a tuple. 
const Layout& layout() const; diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 73bb3327784..b1c96e9becf 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -70,6 +70,8 @@ struct IndexTableEntry { template class ShapeTreeIterator; +template +class ShapeTreeLeafIterator; // A ShapeTree is a recursive data structure which mirrors the structure of a // XLA shape and holds a value of type T for each subshape (i.e. tuple or array) @@ -158,23 +160,25 @@ class ShapeTree { using reverse_iterator = std::reverse_iterator; using const_reverse_iterator = std::reverse_iterator; + using leaf_iterator = + ShapeTreeLeafIterator, + typename std::vector::iterator, + std::pair>; + using const_leaf_iterator = + ShapeTreeLeafIterator, + typename std::vector::const_iterator, + const std::pair>; + using reverse_leaf_iterator = std::reverse_iterator; + using const_reverse_leaf_iterator = + std::reverse_iterator; + // begin/end for iterating over all nodes. - iterator begin() { - return iterator(&nodes_, nodes_.begin(), - /*iterate_leaves_only=*/false); - } - iterator end() { - return iterator(&nodes_, nodes_.end(), - /*iterate_leaves_only=*/false); - } + iterator begin() { return iterator(&nodes_, nodes_.begin()); } + iterator end() { return iterator(&nodes_, nodes_.end()); } const_iterator begin() const { - return const_iterator(&nodes_, nodes_.begin(), - /*iterate_leaves_only=*/false); - } - const_iterator end() const { - return const_iterator(&nodes_, nodes_.end(), - /*iterate_leaves_only=*/false); + return const_iterator(&nodes_, nodes_.begin()); } + const_iterator end() const { return const_iterator(&nodes_, nodes_.end()); } // rbegin/rend for iterating over all nodes in reverse. reverse_iterator rbegin() { return reverse_iterator(end()); } @@ -188,37 +192,33 @@ class ShapeTree { // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no // children). - iterator leaf_begin() { - return iterator(&nodes_, nodes_.begin(), - /*iterate_leaves_only=*/true); + leaf_iterator leaf_begin() { return leaf_iterator(&nodes_, nodes_.begin()); } + leaf_iterator leaf_end() { return leaf_iterator(&nodes_, nodes_.end()); } + const_leaf_iterator leaf_begin() const { + return const_leaf_iterator(&nodes_, nodes_.begin()); } - iterator leaf_end() { - return iterator(&nodes_, nodes_.end(), - /*iterate_leaves_only=*/true); - } - const_iterator leaf_begin() const { - return const_iterator(&nodes_, nodes_.begin(), - /*iterate_leaves_only=*/true); - } - const_iterator leaf_end() const { - return const_iterator(&nodes_, nodes_.end(), - /*iterate_leaves_only=*/true); + const_leaf_iterator leaf_end() const { + return const_leaf_iterator(&nodes_, nodes_.end()); } // range-based iterator for leaf_begin()/leaf_end(). 
- tensorflow::gtl::iterator_range leaves() { + tensorflow::gtl::iterator_range leaves() { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } - tensorflow::gtl::iterator_range leaves() const { + tensorflow::gtl::iterator_range leaves() const { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } - reverse_iterator leaf_rbegin() { return reverse_iterator(leaf_end()); } - reverse_iterator leaf_rend() { return reverse_iterator(leaf_begin()); } - const_reverse_iterator leaf_rbegin() const { - return const_reverse_iterator(leaf_end()); + reverse_leaf_iterator leaf_rbegin() { + return reverse_leaf_iterator(leaf_end()); } - const_reverse_iterator leaf_rend() const { - return const_reverse_iterator(leaf_begin()); + reverse_leaf_iterator leaf_rend() { + return reverse_leaf_iterator(leaf_begin()); + } + const_reverse_leaf_iterator leaf_rbegin() const { + return const_reverse_leaf_iterator(leaf_end()); + } + const_reverse_leaf_iterator leaf_rend() const { + return const_reverse_leaf_iterator(leaf_begin()); } // Returns an iterator pointing to the given ShapeIndex. @@ -226,12 +226,12 @@ class ShapeTree { iterator find(ShapeIndexView index) { Node* element = Lookup(index); auto element_iter = nodes_.begin() + (element - &nodes_[0]); - return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); + return iterator(&nodes_, element_iter); } const_iterator find(ShapeIndexView index) const { - Node* element = Lookup(index); + const Node* element = Lookup(index); auto element_iter = nodes_.cbegin() + (element - &nodes_[0]); - return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); + return const_iterator(&nodes_, element_iter); } // Returns the number of leaf nodes in the tree. @@ -343,21 +343,11 @@ template class ShapeTreeIterator : public std::iterator { public: - ShapeTreeIterator(ContainerType* nodes, IteratorType node, - bool iterate_leaves_only) - : nodes_(nodes), - node_(std::move(node)), - iterate_leaves_only_(iterate_leaves_only) { - while (iterate_leaves_only && node_ != nodes_->end() && !node_->is_leaf) { - ++node_; - } - } + ShapeTreeIterator(ContainerType* nodes, IteratorType node) + : nodes_(nodes), node_(std::move(node)) {} ShapeTreeIterator& operator++() { ++node_; - while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) { - ++node_; - } return *this; } ShapeTreeIterator operator++(int) { @@ -368,9 +358,6 @@ class ShapeTreeIterator ShapeTreeIterator& operator--() { --node_; - while (iterate_leaves_only_ && node_ > nodes_->begin() && !node_->is_leaf) { - --node_; - } return *this; } ShapeTreeIterator operator--(int) { @@ -385,14 +372,66 @@ class ShapeTreeIterator bool operator!=(const ShapeTreeIterator& other) const { return node_ != other.node_; } - ValueType& operator*() { return node_->data; } - ValueType* operator->() { return &node_->data; } + ValueType& operator*() const { return node_->data; } + ValueType* operator->() const { return &node_->data; } + + private: + ContainerType* nodes_; + IteratorType node_; +}; + +// Internal iterator that performs a pre-order walk of the leaves. This is cheap +// to copy. The iterator value_type is equivalent to a std::pair&, +// similar to std::map. 
+template +class ShapeTreeLeafIterator + : public std::iterator { + public: + ShapeTreeLeafIterator(ContainerType* nodes, IteratorType node) + : nodes_(nodes), node_(std::move(node)) { + while (node_ != nodes_->end() && !node_->is_leaf) { + ++node_; + } + } + + ShapeTreeLeafIterator& operator++() { + ++node_; + while (node_ != nodes_->end() && !node_->is_leaf) { + ++node_; + } + return *this; + } + ShapeTreeLeafIterator operator++(int) { + auto i = *this; + ++(*this); + return i; + } + + ShapeTreeLeafIterator& operator--() { + --node_; + while (node_ > nodes_->begin() && !node_->is_leaf) { + --node_; + } + return *this; + } + ShapeTreeLeafIterator operator--(int) { + auto i = *this; + --(*this); + return i; + } + + bool operator==(const ShapeTreeLeafIterator& other) const { + return node_ == other.node_; + } + bool operator!=(const ShapeTreeLeafIterator& other) const { + return node_ != other.node_; + } + ValueType& operator*() const { return node_->data; } + ValueType* operator->() const { return &node_->data; } private: ContainerType* nodes_; IteratorType node_; - // True if we should not include interior nodes in our walk. - const bool iterate_leaves_only_; }; template @@ -648,7 +687,9 @@ void ShapeTree::CopySubtreeFrom(const ShapeTree& other, const ShapeIndex& target_base_index) { CHECK(ShapeUtil::Compatible( ShapeUtil::GetSubshape(shape(), target_base_index), - ShapeUtil::GetSubshape(other.shape(), source_base_index))); + ShapeUtil::GetSubshape(other.shape(), source_base_index))) + << ShapeUtil::GetSubshape(shape(), target_base_index) << " vs " + << ShapeUtil::GetSubshape(other.shape(), source_base_index); ForEachMutableElement([this, &other, &source_base_index, &target_base_index]( const ShapeIndex& index, T* data) { // Copy the data element only if index is in the diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 2b6c484bc4f..c294355e269 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -485,6 +485,30 @@ TEST_F(ShapeTreeTest, ReverseIterateOrder) { })); } +// Ensures that we can find an element at an index that we know ahead of time to +// be occupied in a 'ShapeTree' via the 'find' API. +TEST_F(ShapeTreeTest, Find) { + ShapeTree t(nested_tuple_shape_, 42); + auto found = t.find({1, 0}); + EXPECT_NE(found, t.end()); + // The found key must be the same key we searched for. + EXPECT_EQ(found->first, ShapeIndex({1, 0})); + // The 'ShapeTree' has 42 at every position. + EXPECT_EQ(found->second, 42); +} + +// Ensures that we can find an element at an index that we know ahead of time to +// be occupied in a 'const ShapeTree' via the 'find' API. +TEST_F(ShapeTreeTest, ConstFind) { + const ShapeTree t(nested_tuple_shape_, 42); + auto found = t.find({1, 0}); + EXPECT_NE(found, t.end()); + // The found key must be the same key we searched for. + EXPECT_EQ(found->first, ShapeIndex({1, 0})); + // The 'ShapeTree' has 42 at every position. 
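The new ShapeTreeLeafIterator above folds the old iterate_leaves_only flag into a dedicated type. The pattern it uses is sketched below in stripped-down form (toy Node type, not the XLA classes): both the constructor and operator++ skip entries whose is_leaf flag is false, so leaf_begin()/leaf_end() can simply wrap the underlying container iterators.

#include <cstdio>
#include <vector>

// Toy node type for illustration only.
struct Node {
  bool is_leaf;
  int data;
};

// Minimal "skip non-leaves on construction and on increment" iterator.
class LeafIterator {
 public:
  LeafIterator(const std::vector<Node>* nodes,
               std::vector<Node>::const_iterator it)
      : nodes_(nodes), it_(it) {
    SkipNonLeaves();
  }
  LeafIterator& operator++() {
    ++it_;
    SkipNonLeaves();
    return *this;
  }
  bool operator!=(const LeafIterator& other) const { return it_ != other.it_; }
  const Node& operator*() const { return *it_; }

 private:
  void SkipNonLeaves() {
    while (it_ != nodes_->end() && !it_->is_leaf) ++it_;
  }
  const std::vector<Node>* nodes_;
  std::vector<Node>::const_iterator it_;
};

int main() {
  std::vector<Node> nodes = {{false, 0}, {true, 1}, {false, 2}, {true, 3}};
  for (LeafIterator it(&nodes, nodes.begin()), end(&nodes, nodes.end());
       it != end; ++it) {
    printf("leaf data = %d\n", (*it).data);  // Prints 1 and 3.
  }
  return 0;
}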
+ EXPECT_EQ(found->second, 42); +} + TEST_F(ShapeTreeTest, IterateOrderLeaves) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 02fcaafd19d..0833919b124 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -783,9 +783,18 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ Shape ShapeUtil::ChangeElementType(const Shape& original, PrimitiveType type) { - Shape new_shape = original; - new_shape.set_element_type(type); - return new_shape; + if (original.IsTuple()) { + std::vector new_operands; + new_operands.reserve(original.tuple_shapes_size()); + for (const Shape& operand : original.tuple_shapes()) { + new_operands.push_back(ChangeElementType(operand, type)); + } + return MakeTupleShape(new_operands); + } else { + Shape new_shape = original; + new_shape.set_element_type(type); + return new_shape; + } } /* static */ bool ShapeUtil::IndexIsValid(const Shape& shape, diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 927f9d14883..d9110ed1f35 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -380,12 +380,7 @@ xla_test( name = "conv_depthwise_backprop_filter_test", timeout = "long", srcs = ["conv_depthwise_backprop_filter_test.cc"], - # these backends do not natively handle batch group counts. - disabled_backends = [ - "gpu", - "cpu", - ], - shard_count = 6, + shard_count = 40, deps = [ ":test_macros_header", "//tensorflow/compiler/xla:execution_options_util", @@ -2088,6 +2083,31 @@ xla_test( ], ) +xla_test( + name = "dynamism_inference_test", + srcs = ["dynamism_inference_test.cc"], + deps = [ + ":test_macros_header", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:prng", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", + ], +) + xla_test( name = "compute_constant_test", srcs = ["compute_constant_test.cc"], @@ -2674,5 +2694,6 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "//tensorflow/core/platform:tensor_float_32_utils", ], ) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index a956b85a940..ef4ce24a839 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -1203,6 +1203,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32sTO) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, 
{10.0f, 5.0f, 2.25f, NAN, NAN}); + EqTotalOrder(lhs, rhs); + + ComputeAndCompareR1(&builder, {false, false, true, true, false}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {}); @@ -1222,6 +1232,22 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32sTO) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + // For portability, need to represent NAN using the following call. + // The C++ standard does not specify if quiet_NaN() sets the sign bit of + // its result. The call to std::fabs will ensure that it is not set. + auto nan = std::fabs(std::numeric_limits::quiet_NaN()); + auto lhs = + ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, nan, 6.0f, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, nan, -nan}); + GeTotalOrder(lhs, rhs); + + ComputeAndCompareR1(&builder, {false, true, true, true, false, true}, + {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/buffer_donation_test.cc b/tensorflow/compiler/xla/tests/buffer_donation_test.cc index 856ea7c9b44..f78083fe2af 100644 --- a/tensorflow/compiler/xla/tests/buffer_donation_test.cc +++ b/tensorflow/compiler/xla/tests/buffer_donation_test.cc @@ -61,7 +61,7 @@ class BufferDonationTest : public HloTestBase { absl::Span argument_literals, absl::Span donate_arguments, absl::Span expected_runtime_aliasing, - const Literal& expected) { + const Literal& expected, std::string expected_failure = "") { // Create a copy of the output shape because the HLO module is std::moved // into the compiler and may be deallocated. 
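The EqTotalOrder/GeTotalOrder expectations above assume an IEEE-754-style total ordering in which -NaN sorts below everything and +NaN above everything, so NaN == NaN and NaN >= 10.0 hold. One common way to obtain such an ordering (a sketch of the general technique, not the XLA implementation) is to map each float to a monotone signed-integer key and compare the keys:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>

// Illustrative only: map a float to a signed key such that
// -NaN < -Inf < ... < -0 < +0 < ... < +Inf < +NaN under ordinary integer
// comparison.  Not taken from the XLA sources.
int32_t TotalOrderKey(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  if ((u & 0x80000000u) == 0) {
    return static_cast<int32_t>(u);  // +0, positive finites, +Inf, +NaN.
  }
  // Negative sign bit: larger magnitude must map to a smaller key.
  return -1 - static_cast<int32_t>(u & 0x7FFFFFFFu);
}

int main() {
  float nan = std::fabs(std::numeric_limits<float>::quiet_NaN());
  printf("NaN == NaN  (total order): %d\n",
         TotalOrderKey(nan) == TotalOrderKey(nan));    // 1
  printf("NaN >= 10.0 (total order): %d\n",
         TotalOrderKey(nan) >= TotalOrderKey(10.0f));  // 1
  printf("6.0 >= -NaN (total order): %d\n",
         TotalOrderKey(6.0f) >= TotalOrderKey(-nan));  // 1
  return 0;
}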
const Shape output_shape = hlo_module->result_shape(); @@ -123,10 +123,19 @@ class BufferDonationTest : public HloTestBase { ExecutionInput(std::move(owned_buffers), argument_literal.shape())); } - TF_ASSERT_OK_AND_ASSIGN( - ExecutionOutput output, + StatusOr output_status = executable->ExecuteAsyncOnStream(&service_run_options, std::move(args), - /*hlo_execution_profile=*/nullptr)); + /*hlo_execution_profile=*/nullptr); + if (!expected_failure.empty()) { + ASSERT_FALSE(output_status.ok()); + ASSERT_TRUE(absl::StrContains(output_status.status().error_message(), + expected_failure)) + << "got: \n" + << output_status.status().error_message() << " \nvs want\n" + << expected_failure; + return; + } + ExecutionOutput output = output_status.ConsumeValueOrDie(); se::DeviceMemoryBase result_root_buffer = output.Result().root_buffer(); LOG(INFO) << "result allocation = " << result_root_buffer.opaque() @@ -303,5 +312,37 @@ ENTRY entry { #endif } +TEST_F(BufferDonationTest, TestMustAliasNotDonated) { + HloModuleConfig config; + + StatusOr> module = + ParseAndReturnVerifiedModule(R"( +HloModule module + +ENTRY entry { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT out = (f32[], f32[]) tuple(a, b) +} + )", + config); + + TF_ASSERT_OK(module->get()->input_output_alias_config().SetUpAlias( + {0}, 0, {}, HloInputOutputAliasConfig::kMustAlias)); + + std::vector args; + args.push_back(LiteralUtil::CreateR0(0.1)); + args.push_back(LiteralUtil::CreateR0(0.2)); + Literal expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(0.1), LiteralUtil::CreateR0(0.2)}); + +#ifndef XLA_TEST_BACKEND_INTERPRETER + RunAndCheck(std::move(*module), args, + /*donate_arguments=*/{false, false}, {true, false}, expected, + "An input was configured to be must-alias at " + "compile time but not donated at runtime:"); +#endif +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/cholesky_test.cc b/tensorflow/compiler/xla/tests/cholesky_test.cc index e7f5ca5ed8e..616b404b425 100644 --- a/tensorflow/compiler/xla/tests/cholesky_test.cc +++ b/tensorflow/compiler/xla/tests/cholesky_test.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/tensor_float_32_utils.h" namespace xla { namespace { @@ -181,6 +182,8 @@ class RandomCholeskyTest public ::testing::WithParamInterface {}; XLA_TEST_P(RandomCholeskyTest, Random) { + // Test fails with TensorFloat-32 enabled + tensorflow::enable_tensor_float_32_execution(false); XlaBuilder builder(TestName()); auto test_params = GetParam(); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 0e99ede5d01..6acbb7a9cf0 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -605,7 +605,7 @@ XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, : LiteralSlice(literal)); } -std::unique_ptr +StatusOr> ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, const Literal& literal, const string& name, @@ -637,15 +637,14 @@ Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( return literal.Clone(); } -std::unique_ptr +StatusOr> ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, const DeviceHandle* device_handle, XlaBuilder* builder, XlaOp* data_handle) { Literal param_literal = MaybeConvertLiteralToBfloat16(literal); - std::unique_ptr data = - client_->TransferToServer(param_literal, device_handle) - .ConsumeValueOrDie(); + TF_ASSIGN_OR_RETURN(auto data, + client_->TransferToServer(param_literal, device_handle)); *data_handle = Parameter(builder, parameter_number, param_literal.shape(), name); return data; diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 17bb70bdb42..3c9e37b8fa4 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -270,14 +270,14 @@ class ClientLibraryTestBase : public ManifestCheckingTest { // server, then stores into "data_handle" the global handle for that // parameter. When the use_bfloat16 flag is set but the literal has F32 // elements, the literal will be converted to BF16 before being transferred. - std::unique_ptr CreateParameterAndTransferLiteral( + StatusOr> CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, XlaBuilder* builder, XlaOp* data_handle); // As above, but the caller can specify the device that the literal is // transferred to. If device_handle is nullptr, the literal will be // transferred to the default device. 
- std::unique_ptr CreateParameterAndTransferLiteral( + StatusOr> CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, const DeviceHandle* device_handle, XlaBuilder* builder, XlaOp* data_handle); diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc index ff7e7955876..4a7070a32f3 100644 --- a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc +++ b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc @@ -45,13 +45,20 @@ class BatchGroupedConvolution2DTest public ::testing::WithParamInterface< ::testing::tuple> {}; -static std::vector GetConv2DTestCases() { +class BatchGroupedConvolution2DDepthTest + : public HloTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> {}; + +static std::vector GetConv2DTestCases( + bool use_depth_multiplier) { std::vector config_set; std::vector> config_options = { - {8, 5, 3, 2}, {4, 5, 5, 2}, {8, 7, 4, 128}, - {16, 20, 20, 256}, {256, 7, 5, 4}, {256, 6, 6, 4}, - {256, 8, 8, 512}, {64, 7, 7, 960}, {64, 14, 14, 576}}; + {129, 10, 3, 2}, {4, 3, 3, 258}, {8, 4, 2, 128}, + {8, 3, 2, 256}, {256, 7, 5, 4}, {128, 6, 6, 4}, + {32, 5, 2, 129}, {16, 4, 3, 2}, {16, 3, 2, 64}}; + int64 counter = 2; for (auto option : config_options) { int64 feature = option[3]; int64 activation_size = option[1]; @@ -65,10 +72,16 @@ static std::vector GetConv2DTestCases() { config.activation_dims = {batch, activation_size, activation_size, feature}; - config.kernel_dims = {batch, kernel_size, kernel_size, feature}; - + const int64 depthwise_multiplier = use_depth_multiplier ? counter++ : 1; + config.kernel_dims = {batch, kernel_size, kernel_size, + feature * depthwise_multiplier}; + // Don't let the counter grow too much, else the compute demand will grow. + if (counter == 4) { + counter = 2; + } int64 output_space_size = 3 + activation_size - kernel_size; - config.output_dims = {output_space_size, output_space_size, feature, 1}; + config.output_dims = {output_space_size, output_space_size, + feature * depthwise_multiplier, 1}; config.activation_and_kernel_layout = {0, 3, 1, 2}; config.output_layout = {2, 3, 0, 1}; @@ -123,11 +136,13 @@ string BatchGroupedConvolution2DTestDataToString( } string BuildHloTextBatchGroupedConvolution2D( - const BatchGroupedConvolution2DSpec& spec, bool use_bfloat16) { + const BatchGroupedConvolution2DSpec& spec, bool use_bfloat16, + bool scheduled = false) { const string data_type = GetFloatDataType(use_bfloat16); + const string scheduled_tag = scheduled ? 
",is_scheduled=true" : ""; return absl::StrFormat( R"( - HloModule TensorFlowDepthwiseConv, is_scheduled=true + HloModule TensorFlowDepthwiseConv %s ENTRY main { activation = %s[%s]{%s} parameter(0) @@ -137,7 +152,7 @@ string BuildHloTextBatchGroupedConvolution2D( batch_group_count=%d } )", - data_type, absl::StrJoin(spec.activation_dims, ","), + scheduled_tag, data_type, absl::StrJoin(spec.activation_dims, ","), absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type, absl::StrJoin(spec.kernel_dims, ","), absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type, @@ -161,23 +176,26 @@ XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) { } #endif - const string hlo_text = - BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16); + const string hlo_text = BuildHloTextBatchGroupedConvolution2D( + spec, use_bfloat16, /*scheduled=*/false); - EXPECT_TRUE(RunAndCompareNoHloPasses( - hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status { - BFloat16MixedPrecisionRemoval remover; - TF_RETURN_IF_ERROR(remover.Run(module).status()); - Despecializer despecializer; - return despecializer.Run(module).status(); - })); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01})); } INSTANTIATE_TEST_CASE_P( BatchGroupedConvolution2DTestWithRandomIndices, BatchGroupedConvolution2DTest, - ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), - ::testing::Bool()), + ::testing::Combine( + ::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/false)), + ::testing::Bool()), + BatchGroupedConvolution2DTestDataToString); + +INSTANTIATE_TEST_CASE_P( + BatchGroupedConvolution2DDepthMultiplierTestWithRandomIndices, + BatchGroupedConvolution2DTest, + ::testing::Combine( + ::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/true)), + ::testing::Bool()), BatchGroupedConvolution2DTestDataToString); } // namespace diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 60ba27b2050..e06e2972f1c 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -69,12 +69,14 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { XlaBuilder builder(TestName()); XlaOp param; - auto param_data = CreateParameterAndTransferLiteral( - 0, - LiteralUtil::MakeTupleFromSlices( - {LiteralUtil::CreateR2({{1, 2}, {3, 4}}), - LiteralUtil::CreateR2({{5, 6}, {7, 8}})}), - "arg0", &builder, ¶m); + TF_ASSERT_OK_AND_ASSIGN( + auto param_data, + CreateParameterAndTransferLiteral( + 0, + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1, 2}, {3, 4}}), + LiteralUtil::CreateR2({{5, 6}, {7, 8}})}), + "arg0", &builder, ¶m)); auto lhs = GetTupleElement(param, 0); auto rhs = GetTupleElement(param, 1); Dot(lhs, rhs); diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc new file mode 100644 index 00000000000..a7e032448e0 --- /dev/null +++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc @@ -0,0 +1,242 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/strings/match.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// An enumerator for the client types that we want to iterate over in +// the various tests. +enum class ClientType { kLocal, kCompileOnly }; +ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly}; + +class DynamismInferenceTest : public ::testing::Test { + public: + explicit DynamismInferenceTest(se::Platform* platform = nullptr) + : platform_(platform) {} + + string TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } + + Client* ClientOrDie(se::Platform* platform, ClientType client_type) { + if (client_type == ClientType::kLocal) { + StatusOr result = + ClientLibrary::GetOrCreateLocalClient(platform); + TF_CHECK_OK(result.status()) + << "could not create LocalClient for testing"; + return result.ValueOrDie(); + } else if (client_type == ClientType::kCompileOnly) { + StatusOr result = + ClientLibrary::GetOrCreateCompileOnlyClient(platform); + TF_CHECK_OK(result.status()) + << "could not create CompileOnlyClient for testing"; + return result.ValueOrDie(); + } + LOG(FATAL) << "invalid client_type value"; + } + + StatusOr ComputeDynamismLiteral(Client* client, XlaOp operand, + XlaBuilder* builder, + Layout* output_layout = nullptr) { + TF_ASSIGN_OR_RETURN(auto subgraph, + builder->BuildDynamicInferenceGraph(operand)); + TF_ASSIGN_OR_RETURN(auto computed, + client->ComputeConstant(subgraph, output_layout)); + return std::move(computed); + } + + StatusOr ComputeDynamismScalar(Client* client, XlaOp operand, + XlaBuilder* builder, + ShapeIndex index = {}) { + TF_ASSIGN_OR_RETURN(auto literal, ComputeDynamismLiteral(client, operand, + builder, nullptr)); + return literal.Get({}, index); + } + + se::Platform* platform_; +}; + +TEST_F(DynamismInferenceTest, ScalarInt32Literal) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto computation = ConstantR0(&b, 42); + + auto value = ComputeDynamismScalar(client, computation, &b); + ASSERT_TRUE(value.ok()) << value.status(); + // A constant is 
not dynamic. + EXPECT_EQ(value.ValueOrDie(), false); + } +} + +TEST_F(DynamismInferenceTest, TupleSimple) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); + + auto tuple = Tuple(&b, {c, p}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {0}).ValueOrDie(), + false); + EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {1}).ValueOrDie(), true); + } +} + +TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); + + auto tuple = Tuple(&b, {c, p}); + auto gte0 = GetTupleElement(tuple, 0); + auto gte1 = GetTupleElement(tuple, 1); + auto tuple_2 = Tuple(&b, {gte0, gte1}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + false); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + true); + } +} + +TEST_F(DynamismInferenceTest, PredValueUsedTwice) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); + auto pred = Eq(c, p); + auto result = Select(pred, p, c); + EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(), + false); + } +} + +TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); + + auto concat = ConcatScalars(&b, {c, p}); + auto slice0 = SliceInDim(concat, 0, 1, 1, 0); + auto reshape0 = Reshape(slice0, {}); + auto slice1 = SliceInDim(concat, 1, 2, 1, 0); + auto reshape1 = Reshape(slice1, {}); + auto tuple_2 = Tuple(&b, {reshape0, reshape1}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + false); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + true); + } +} + +TEST_F(DynamismInferenceTest, ParameterIsDynamic) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); + + auto value = ComputeDynamismScalar(client, computation, &b); + ASSERT_TRUE(value.ok()) << value.status(); + // A parameter is considered dynamic. 
+ EXPECT_EQ(value.ValueOrDie(), true); + } +} + +TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); + + auto neg0 = Neg(c); + auto neg1 = Neg(p); + auto tuple_2 = Tuple(&b, {neg0, neg1}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + false); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + true); + } +} + +TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); + + // Static value + static value = static + auto add1 = Add(c, c); + // Dynamic value + dynamic value = dynamic + auto add2 = Add(p, c); + auto tuple_2 = Tuple(&b, {add1, add2}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + false); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + true); + } +} + +TEST_F(DynamismInferenceTest, GetDimensionSize) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + // param = Param([<=2, 3]) + // get_dimension_size(param, 0) is dynamic + // get_dimension_size(param, 1) is static + auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), + "p0"); + + auto gds0 = GetDimensionSize(p, 0); + auto gds1 = GetDimensionSize(p, 1); + auto tuple_2 = Tuple(&b, {gds0, gds1}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + true); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + false); + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/exhaustive_binary_16_bit_test.cc b/tensorflow/compiler/xla/tests/exhaustive_binary_16_bit_test.cc index 09c91d4be14..dca8e31e792 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_binary_16_bit_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_binary_16_bit_test.cc @@ -123,8 +123,16 @@ BINARY_TEST_16BIT(Min, { }) // TODO(bixia): Pow fails with bfloat16 on CPU. -BINARY_TEST_16BIT(DISABLED_ON_CPU(Pow), - { Run(AddEmptyBroadcastDimension(Pow), std::pow); }) +BINARY_TEST_16BIT(DISABLED_ON_CPU(Pow), { + // See b/162664705. + known_incorrect_fn_ = [](int64 val) { + Eigen::bfloat16 f; + uint16_t val_16 = val; + memcpy(&f, &val_16, 2); + return std::isnan(f); + }; + Run(AddEmptyBroadcastDimension(Pow), std::pow); +}) // TODO(bixia): Atan2 fails with bfloat16 on CPU. BINARY_TEST_16BIT(DISABLED_ON_CPU(Atan2), diff --git a/tensorflow/compiler/xla/tests/exhaustive_binary_test_f32_f64.cc b/tensorflow/compiler/xla/tests/exhaustive_binary_test_f32_f64.cc index 14d3b343b6c..c6feedf9e7f 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_binary_test_f32_f64.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_binary_test_f32_f64.cc @@ -114,6 +114,10 @@ BINARY_TEST_FLOAT_32(Min, { // // TODO(bixia): Need to investigate the failure on CPU and file bugs. BINARY_TEST_FLOAT_32(DISABLED_ON_CPU(AbsComplex), { + // TODO(timshen): see b/162664705. 
+ known_incorrect_fn_ = [this](int64 val) { + return std::isnan(this->ConvertValue(val)); + }; auto host_abs_complex = [](float x, float y) { return std::abs(std::complex(x, y)); }; @@ -198,6 +202,10 @@ BINARY_TEST_FLOAT_64(Min, { // TODO(bixia): Need to investigate the failure on CPU and file bugs. BINARY_TEST_FLOAT_64(DISABLED_ON_CPU(AbsComplex), { + // TODO(timshen): see b/162664705. + known_incorrect_fn_ = [this](int64 val) { + return std::isnan(this->ConvertValue(val)); + }; auto host_abs_complex = [](double x, double y) { return std::abs(std::complex(x, y)); }; diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test_complex.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test_complex.cc index b361bf94a6d..6a638d2106f 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_unary_test_complex.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test_complex.cc @@ -97,6 +97,10 @@ using ExhaustiveC128UnaryTest = ExhaustiveComplexUnaryTestBase; // TODO(b/138578594): Enable the test for the CPU backend after fixing the bug. UNARY_TEST_COMPLEX_64(DISABLED_ON_CPU(Log), { + // TODO(timshen): see b/162664705. + known_incorrect_fn_ = [this](int64 val) { + return std::isnan(this->ConvertValue(val)); + }; Run(Log, [](complex64 x) { return std::log(x); }); }) diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 0fd5f191db0..0f8a4c1e273 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -711,6 +711,24 @@ ENTRY main { RunTest(hlo_text, &operand, &start_indices); } +XLA_TEST_F(GatherOperationTest, GatherFromScalarNonZeroIndices) { + const string hlo_text = R"( +HloModule GatherFromScalar + +ENTRY main { + operand = f32[1,1,1] parameter(0) + indices = s32[2,3,50] parameter(1) + ROOT gather = f32[1,2,50] gather(operand, indices), + offset_dims={0}, + collapsed_slice_dims={0,1}, + start_index_map={1,0,2}, + index_vector_dim=1, + slice_sizes={1,1,1} +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0, 0})); +} + class GatherClientLibraryTest : public ClientLibraryTestBase {}; // Disabled on interpreter since ExecuteAsyncOnStream is not supported. 
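For reviewers new to the dynamism-inference API exercised by the new dynamism_inference_test.cc above, the following is a minimal sketch of the query pattern that the test's ComputeDynamismScalar helper wraps. IsScalarDynamic is a hypothetical wrapper written only for illustration (it is not part of this change), the includes mirror those of the test file, and it assumes a Client* obtained as in the test's ClientOrDie helper.

#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"

// Hypothetical helper (illustration only): returns true if the scalar value
// produced by `op` can depend on a parameter, i.e. is dynamic.
xla::StatusOr<bool> IsScalarDynamic(xla::Client* client,
                                    xla::XlaBuilder* builder, xla::XlaOp op) {
  // Build a parallel PRED-valued graph in which each element reports whether
  // the corresponding element of `op` is dynamic.
  TF_ASSIGN_OR_RETURN(xla::XlaComputation subgraph,
                      builder->BuildDynamicInferenceGraph(op));
  // Constant-fold that graph on the client; the tests above run this against
  // both the local and the compile-only client.
  TF_ASSIGN_OR_RETURN(xla::Literal dynamism,
                      client->ComputeConstant(subgraph));
  // A scalar op yields a single PRED element.
  return dynamism.Get<bool>({});
}

Applied to a ConstantR0 this yields false, and applied to a Parameter it yields true, which is what the ScalarInt32Literal and ParameterIsDynamic tests above assert.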
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index d0b6e5f80ed..663e7d81006 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -230,6 +230,19 @@ StatusOr> HloTestBase::ExecuteReplicated( device_assignment); } +StatusOr> HloTestBase::ExecuteReplicated( + std::function executable_provider, + std::function argument_count_provider, + std::function argument_provider, + int64 num_replicas, bool run_hlo_passes) { + HloRunner::ReplicatedExecuteOptions options; + options.num_replicas = num_replicas; + options.run_hlo_passes = run_hlo_passes; + options.use_threads = true; + return test_runner_.ExecuteReplicated( + executable_provider, argument_count_provider, argument_provider, options); +} + StatusOr> HloTestBase::MakeReferenceModule( const HloModule& test_module, const std::function& reference_preprocessor) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 17c2a55ba5b..fc680e39682 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -169,6 +169,13 @@ class HloTestBase : public ManifestCheckingTest { int64 num_replicas, DeviceAssignment* device_assignment, bool run_hlo_passes, bool use_threads); + // Same as above, but allows passing different programs for replicas. + StatusOr> ExecuteReplicated( + std::function executable_provider, + std::function argument_count_provider, + std::function argument_provider, + int64 num_replicas, bool run_hlo_passes); + // Executes the given hlo module on two backends and compares results. // // 'arguments': the input of the hlo module. diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index b209669715e..7e5b699d5e2 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -365,8 +365,9 @@ XLA_TEST_P(ReduceWindowTest, R4UnitWindow) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); XlaOp input; - auto input_data = CreateParameterAndTransferLiteral( - 0, input_literal, "parameter", &builder_, &input); + TF_ASSERT_OK_AND_ASSIGN( + auto input_data, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder_, &input)); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); @@ -423,8 +424,9 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; - auto input_data = CreateParameterAndTransferLiteral( - 0, input_literal, "parameter", &builder_, &input); + TF_ASSERT_OK_AND_ASSIGN( + auto input_data, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder_, &input)); int win_len = 1; int stride = 8; @@ -444,8 +446,9 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; - auto input_data = CreateParameterAndTransferLiteral( - 0, input_literal, "parameter", &builder_, &input); + TF_ASSERT_OK_AND_ASSIGN( + auto input_data, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder_, &input)); int win_len = 3; int stride = 1; @@ -465,8 +468,9 @@ 
XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; - auto input_data = CreateParameterAndTransferLiteral( - 0, input_literal, "parameter", &builder_, &input); + TF_ASSERT_OK_AND_ASSIGN( + auto input_data, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder_, &input)); int win_len = 8; int stride = 5; @@ -631,8 +635,9 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", - &b, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_arg, + CreateParameterAndTransferLiteral( + 0, input_literal, "p0", &b, ¶meter)); std::vector> padding(4); for (int i = 0; i < 4; ++i) { @@ -1243,7 +1248,9 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); + TF_ASSERT_OK(CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, + ¶meter) + .status()); std::vector> padding(2); for (int i = 0; i < 2; ++i) { @@ -1443,8 +1450,9 @@ XLA_TEST_P(R1ReduceWindowTest, DoIt) { Literal input_literal = LiteralUtil::CreateR1(absl::Span(input_vector)); XlaOp parameter; - auto input_arg = - CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input_arg, CreateParameterAndTransferLiteral(0, input_literal, "p0", + &b, ¶meter)); std::vector> padding(1); padding[0] = {param.pad_low[0], param.pad_high[0]}; diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 298136002e9..890156cc650 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -57,8 +57,9 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); @@ -70,8 +71,9 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); @@ -83,8 +85,9 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{0}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); @@ -99,8 +102,9 @@ 
XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral( + 0, input_literal, "parameter", &builder, ¶meter)); auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); @@ -115,8 +119,9 @@ XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { Literal param0_literal = LiteralUtil::CreateR0(1.0f); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, param0_literal, "param0", + &builder, ¶meter)); auto a = Neg(parameter); Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); @@ -130,8 +135,9 @@ XLA_TEST_P(ReshapeTest, Trivial0x3) { Array2D input_array(0, 3); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, @@ -144,8 +150,9 @@ XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { Literal param0_literal = LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, param0_literal, "param0", + &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, @@ -157,8 +164,9 @@ XLA_TEST_P(ReshapeTest, Trivial3x0) { Array2D input_array(3, 0); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, @@ -170,8 +178,9 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, @@ -183,8 +192,9 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f}, {2.0f}, {3.0f}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); 
+ TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, @@ -196,8 +206,9 @@ XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); @@ -211,8 +222,9 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 3}); auto expected_literal = @@ -226,8 +238,9 @@ XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 2)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); @@ -241,8 +254,9 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); auto input_literal = LiteralUtil::CreateFromArray(*simple); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 1}); @@ -258,8 +272,9 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 4}); @@ -274,8 +289,9 @@ XLA_TEST_P(ReshapeTest, Transpose0x4) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 4)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Transpose(parameter, {1, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}, {}, {}}); ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, @@ -288,8 +304,9 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); 
XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Transpose(parameter, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); @@ -304,8 +321,9 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(6, 0)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 3, 0, 0}); auto expected_literal = @@ -318,8 +336,9 @@ XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 4, 0)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{24, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(24, 0)); @@ -334,8 +353,9 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 6}); @@ -349,8 +369,9 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 6)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(3, 0)); @@ -365,8 +386,9 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, @@ -391,8 +413,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{24}); auto expected_literal = 
LiteralUtil::CreateR1( @@ -406,8 +429,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{8, 3}); auto expected_literal = LiteralUtil::CreateR2({{10, 11, 12}, @@ -426,8 +450,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{24}); auto expected_literal = LiteralUtil::CreateR1( @@ -441,8 +466,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{8, 3}); auto expected_literal = LiteralUtil::CreateR2({{10, 20, 30}, @@ -461,8 +487,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{2, 6, 2}); auto expected_literal = LiteralUtil::CreateR3( @@ -494,8 +521,9 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { t2x2x2x3.FillWithYX(*filler2x3); auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); auto expected_literal = LiteralUtil::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, @@ -519,8 +547,9 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 1, 1) = 7; auto input_literal = LiteralUtil::CreateFromArray(t); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 4}); @@ -542,8 +571,9 @@ XLA_TEST_P(ReshapeTest, ToScalar) { input_literal.Set(zeros, 83.0f); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &b, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, 
CreateParameterAndTransferLiteral(0, input_literal, "input", + &b, ¶meter)); Reshape(parameter, dimensions, {}); auto expected_literal = LiteralUtil::CreateR0(83.0f); @@ -556,8 +586,9 @@ XLA_TEST_P(ReshapeTest, BadDimensions) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, - ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &b, ¶meter)); Reshape(parameter, {}, {}); EXPECT_THAT( ExecuteToString(&b, {}), @@ -568,8 +599,9 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, - ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &b, ¶meter)); Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), ::testing::HasSubstr("mismatched element counts")); @@ -604,8 +636,9 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); @@ -639,8 +672,9 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off @@ -666,8 +700,9 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off @@ -694,8 +729,9 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); Literal expected = LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, input_literal); @@ -713,8 +749,9 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); Literal expected = 
LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, input_literal); @@ -733,8 +770,9 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60}); @@ -759,8 +797,9 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, /*new_sizes=*/{7, 2, 3, 5}); XlaComputation computation = builder.Build().ConsumeValueOrDie(); @@ -793,8 +832,9 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", + &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{1, 2, 3, 4}); @@ -808,8 +848,9 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { XlaBuilder builder(TestName()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN( + auto input, CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", + &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, /*new_sizes=*/{2, 4, 3, 1}); @@ -840,8 +881,9 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); @@ -867,8 +909,9 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); @@ -894,8 +937,9 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); @@ -922,8 +966,9 @@ XLA_TEST_P(ReshapeTest, 
R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); @@ -949,8 +994,9 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { input, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", - &builder, ¶meter); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + CreateParameterAndTransferLiteral( + 0, input_literal, "input", &builder, ¶meter)); Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); diff --git a/tensorflow/compiler/jit/union_find.h b/tensorflow/compiler/xla/union_find.h similarity index 100% rename from tensorflow/compiler/jit/union_find.h rename to tensorflow/compiler/xla/union_find.h diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 1fbce96625b..4034e5fdd27 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -31,10 +31,10 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/numbers.h" diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index e8b6105d3fe..d334f879c3e 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -627,6 +627,11 @@ message OpSharding { // applied, this is inferred from the instruction this sharding gets attached // to. repeated OpSharding tuple_shardings = 5; + + // Only used for OTHER type. If true, data is sharded according to other + // dimensions of tile_assignment(), but replicated across devices along the + // last dimension. (Experimental) + bool replicate_on_last_tile_dim = 6; } // Describes the replica groups in a cross replica op (e.g., all-reduce and diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index 6a704be4adb..172a970d207 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -96,6 +96,7 @@ tf_gen_op_libs( "xrt_execute_op", ], deps = [ + "//tensorflow/compiler/jit:flags", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index a4be39b96c6..321d7409103 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" @@ -20,6 +21,11 @@ limitations under the License. namespace tensorflow { +static bool Initialized = [] { + tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true; + return true; +}(); + REGISTER_OP("XRTAllocate") .Input("allocation: string") .Output("handle: int64") diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 161a0a95856..6da5c43ce82 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -72,7 +72,6 @@ load( "if_ios", "if_mobile", "if_not_windows", - "if_tpu", "tf_android_core_proto_headers", "tf_cc_test", "tf_cc_test_mkl", @@ -117,6 +116,7 @@ load( "tf_protos_all_impl", "tf_protos_grappler_impl", "tf_protos_profiler_impl", + "tf_tpu_dependencies", ) load( "//tensorflow/core/platform:rules_cc.bzl", @@ -318,7 +318,6 @@ alias( cc_library( name = "lib_proto_parsing", hdrs = [ - "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/lib/core:legacy_lib_proto_parsing_headers", "//tensorflow/core/lib/strings:legacy_lib_proto_parsing_headers", "//tensorflow/core/platform:lib_proto_parsing_hdrs", @@ -328,7 +327,6 @@ cc_library( ":platform_base", "@com_google_absl//absl/strings", "@double_conversion//:double-conversion", - "//tensorflow/core/lib/bfloat16", "//tensorflow/core/lib/core:errors", "//tensorflow/core/lib/core:stringpiece", "//tensorflow/core/lib/core:status", @@ -353,6 +351,7 @@ cc_library( cc_library( name = "lib", hdrs = [ + # TODO(rmlarsen): Remove bfloat16.h once dependency in third_party/swift is updated. "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/lib/core:legacy_lib_core_headers", "//tensorflow/core/lib/gtl:legacy_lib_gtl_headers", @@ -489,6 +488,7 @@ tf_cuda_library( "//tensorflow/core/framework:register_types_traits.h", "//tensorflow/core/framework:resource_mgr.h", "//tensorflow/core/framework:resource_op_kernel.h", + "//tensorflow/core/framework:rng_alg.h", "//tensorflow/core/framework:selective_registration.h", "//tensorflow/core/framework:session_state.h", "//tensorflow/core/framework:shape_inference.h", @@ -582,7 +582,6 @@ cc_library( "//tensorflow/core/framework:numeric_types.h", "//tensorflow/core/framework:tensor_types.h", "//tensorflow/core/framework:type_traits.h", - "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/platform:framework_lite_hdrs", "//tensorflow/core/platform/default:integral_types.h", "//tensorflow/core/platform/default:logging.h", @@ -593,7 +592,6 @@ cc_library( "@nsync//:nsync_cpp", ] + [ "//third_party/eigen3", - "//tensorflow/core/lib/bfloat16", "//tensorflow/core/platform:dynamic_annotations", "//tensorflow/core/platform:platform_port", "//tensorflow/core/platform:thread_annotations", @@ -629,6 +627,7 @@ tf_gen_op_libs( "io_ops", "linalg_ops", "list_ops", + "map_ops", "lookup_ops", "manip_ops", "math_ops", @@ -654,6 +653,7 @@ tf_gen_op_libs( "spectral_ops", "state_ops", "stateless_random_ops", + "stateless_random_ops_v2", "summary_ops", "training_ops", ], @@ -672,6 +672,8 @@ tf_gen_op_libs( ":lib", ":protos_all_cc", # TODO(b/162630222): remove this dependency. 
+ "//tensorflow/c/kernels:histogram_summary_op_lib", + "//tensorflow/c/kernels:merge_summary_op_lib", "//tensorflow/c/kernels:summary_op_lib", ], ) @@ -843,6 +845,7 @@ cc_library( ":io_ops_op_lib", ":linalg_ops_op_lib", ":list_ops_op_lib", + ":map_ops_op_lib", ":logging_ops_op_lib", ":lookup_ops_op_lib", ":manip_ops_op_lib", @@ -870,11 +873,14 @@ cc_library( ":spectral_ops_op_lib", ":state_ops_op_lib", ":stateless_random_ops_op_lib", + ":stateless_random_ops_v2_op_lib", ":string_ops_op_lib", ":training_ops_op_lib", ":user_ops_op_lib", ":word2vec_ops", "//tensorflow/c/kernels:bitcast_op_lib", + "//tensorflow/c/kernels:histogram_summary_op_lib", + "//tensorflow/c/kernels:merge_summary_op_lib", "//tensorflow/c/kernels:summary_op_lib", "//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op", ] + if_chromiumos( @@ -983,8 +989,10 @@ cc_library( name = "all_kernels_impl", visibility = [":__subpackages__"], deps = [ - "//tensorflow/c/kernels:summary_op", "//tensorflow/c/kernels:bitcast_op", + "//tensorflow/c/kernels:histogram_summary_op", + "//tensorflow/c/kernels:merge_summary_op", + "//tensorflow/c/kernels:summary_op", "//tensorflow/core/kernels:array", "//tensorflow/core/kernels:audio", "//tensorflow/core/kernels:batch_kernels", @@ -1008,9 +1016,8 @@ cc_library( "//tensorflow/core/kernels:functional_ops", "//tensorflow/core/kernels:grappler", "//tensorflow/core/kernels:histogram_op", - "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", - "//tensorflow/core/kernels/linalg:linalg", + "//tensorflow/core/kernels:isotonic_regression_op", "//tensorflow/core/kernels:lookup", "//tensorflow/core/kernels:logging", "//tensorflow/core/kernels:manip", @@ -1044,32 +1051,34 @@ cc_library( "//tensorflow/core/kernels:summary_kernels", "//tensorflow/core/kernels:training_ops", "//tensorflow/core/kernels:word2vec_kernels", + "//tensorflow/core/kernels/linalg:linalg", + "//tensorflow/core/kernels/image:image", "//tensorflow/core/kernels/sparse:kernels", ] + if_not_windows([ "//tensorflow/core/kernels/neon:neon_depthwise_conv_op", ]) + if_mkl([ - "//tensorflow/core/kernels:mkl_aggregate_ops", - "//tensorflow/core/kernels:mkl_concat_op", - "//tensorflow/core/kernels:mkl_dequantize_op", - "//tensorflow/core/kernels:mkl_conv_op", - "//tensorflow/core/kernels:mkl_cwise_ops_common", - "//tensorflow/core/kernels:mkl_fused_batch_norm_op", - "//tensorflow/core/kernels:mkl_identity_op", - "//tensorflow/core/kernels:mkl_input_conversion_op", - "//tensorflow/core/kernels:mkl_lrn_op", - "//tensorflow/core/kernels:mkl_pooling_ops", - "//tensorflow/core/kernels:mkl_qmatmul_op", - "//tensorflow/core/kernels:mkl_requantize_ops", - "//tensorflow/core/kernels:mkl_quantize_op", - "//tensorflow/core/kernels:mkl_relu_op", - "//tensorflow/core/kernels:mkl_reshape_op", - "//tensorflow/core/kernels:mkl_slice_op", - "//tensorflow/core/kernels:mkl_softmax_op", - "//tensorflow/core/kernels:mkl_transpose_op", - "//tensorflow/core/kernels:mkl_batch_matmul_op", - "//tensorflow/core/kernels:mkl_matmul_op", - "//tensorflow/core/kernels:mkl_tfconv_op", - "//tensorflow/core/kernels:mkl_tmp_bf16_ops", + "//tensorflow/core/kernels/mkl:mkl_aggregate_ops", + "//tensorflow/core/kernels/mkl:mkl_concat_op", + "//tensorflow/core/kernels/mkl:mkl_dequantize_op", + "//tensorflow/core/kernels/mkl:mkl_conv_op", + "//tensorflow/core/kernels/mkl:mkl_cwise_ops_common", + "//tensorflow/core/kernels/mkl:mkl_fused_batch_norm_op", + "//tensorflow/core/kernels/mkl:mkl_identity_op", + "//tensorflow/core/kernels/mkl:mkl_input_conversion_op", + 
"//tensorflow/core/kernels/mkl:mkl_lrn_op", + "//tensorflow/core/kernels/mkl:mkl_pooling_ops", + "//tensorflow/core/kernels/mkl:mkl_qmatmul_op", + "//tensorflow/core/kernels/mkl:mkl_requantize_ops", + "//tensorflow/core/kernels/mkl:mkl_quantize_op", + "//tensorflow/core/kernels/mkl:mkl_relu_op", + "//tensorflow/core/kernels/mkl:mkl_reshape_op", + "//tensorflow/core/kernels/mkl:mkl_slice_op", + "//tensorflow/core/kernels/mkl:mkl_softmax_op", + "//tensorflow/core/kernels/mkl:mkl_transpose_op", + "//tensorflow/core/kernels/mkl:mkl_batch_matmul_op", + "//tensorflow/core/kernels/mkl:mkl_matmul_op", + "//tensorflow/core/kernels/mkl:mkl_tfconv_op", + "//tensorflow/core/kernels/mkl:mkl_tmp_bf16_ops", ]) + if_cuda_or_rocm([ "//tensorflow/core/kernels:cudnn_rnn_kernels", ]) + if_cuda([ @@ -1080,9 +1089,7 @@ cc_library( ]) + if_tensorrt([ "//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels", "//tensorflow/compiler/tf2tensorrt:trt_op_kernels", - ]) + if_tpu([ - "//tensorflow/core/tpu/kernels", - ]), + ]) + tf_tpu_dependencies(), ) cc_library( @@ -1107,6 +1114,8 @@ cc_library( # these also dynamically loading. "//tensorflow/core/kernels:dataset_ops", # Depends on grappler "//tensorflow/core/kernels:list_kernels", # Depends on variant_op_registry.h + "//tensorflow/core/kernels:map_kernels", + "//tensorflow/core/kernels:tensor_map", ], ) @@ -1158,7 +1167,7 @@ cc_library( ) # Test support library needed for higher-level (TensorFlow-specific) tests -cc_library( +tf_cuda_library( name = "testlib", testonly = 1, srcs = [ @@ -1251,7 +1260,6 @@ filegroup( "//tensorflow/core/example:mobile_srcs_no_runtime", "//tensorflow/core/framework:attr_value_proto_text_srcs", "//tensorflow/core/framework:mobile_srcs_no_runtime", - "//tensorflow/core/lib/bfloat16:mobile_srcs_no_runtime", "//tensorflow/core/lib/core:mobile_srcs_no_runtime", "//tensorflow/core/lib/gtl:mobile_srcs_no_runtime", "//tensorflow/core/lib/hash:mobile_srcs_no_runtime", @@ -1290,6 +1298,7 @@ filegroup( "//tensorflow/core/graph:mobile_srcs_only_runtime", "//tensorflow/core/kernels:mobile_srcs", "//tensorflow/core/lib/io:mobile_srcs_only_runtime", + "//tensorflow/core/nccl:mobile_srcs", "//tensorflow/core/profiler:mobile_srcs", "//tensorflow/core/public:mobile_srcs_only_runtime", "//tensorflow/core/util/sparse:mobile_srcs_only_runtime", @@ -1689,7 +1698,6 @@ filegroup( "//tensorflow/core/framework:resource_handle.h", "//tensorflow/core/platform:legacy_lib_internal_headers", "//tensorflow/core/platform:lib_internal_private_hdrs", - "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/lib/core:legacy_lib_core_all_headers", "//tensorflow/core/lib/gtl:legacy_lib_gtl_all_headers", "//tensorflow/core/lib/histogram:legacy_lib_histogram_all_headers", @@ -1806,7 +1814,6 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "//third_party/eigen3", - "//tensorflow/core/lib/bfloat16", "//tensorflow/core/lib/core:arena", "//tensorflow/core/lib/core:bitmap", "//tensorflow/core/lib/core:blocking_counter", @@ -1887,6 +1894,7 @@ cc_library( "//tensorflow/core/lib/strings:strcat", "//tensorflow/core/lib/strings:stringprintf", "//tensorflow/core/platform:abi", + "//tensorflow/core/platform:bfloat16", "//tensorflow/core/platform:base64", "//tensorflow/core/platform:blocking_counter", "//tensorflow/core/platform:casts", @@ -1973,6 +1981,7 @@ cc_library( ":lib", ":lib_internal", "//tensorflow/core/platform:gif", + "@com_google_absl//absl/strings", ], ) @@ -2005,6 +2014,11 @@ alias( actual = 
"//tensorflow/core/lib/png:png_io", ) +alias( + name = "portable_png_internal", + actual = "//tensorflow/core/lib/png:png_io", +) + alias( name = "android_png_internal", actual = "//tensorflow/core/lib/png:png_io", @@ -2013,7 +2027,6 @@ alias( cc_library( name = "tflite_portable_logging", hdrs = [ - "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/platform:tflite_portable_logging_hdrs", "//tensorflow/core/platform/default:integral_types.h", "//tensorflow/core/platform/default:logging.h", @@ -2034,8 +2047,8 @@ cc_library( ) cc_library( - name = "android_jpeg_internal", - srcs = if_android([ + name = "portable_jpeg_internal", + srcs = if_mobile([ "lib/jpeg/jpeg_handle.cc", "lib/jpeg/jpeg_mem.cc", "//tensorflow/core/platform:jpeg_hdrs", @@ -2043,14 +2056,13 @@ cc_library( hdrs = [ "lib/jpeg/jpeg_handle.h", "lib/jpeg/jpeg_mem.h", - "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/lib/core:legacy_lib_core_stringpiece_header", "//tensorflow/core/platform:jpeg_internal_hdrs", "//tensorflow/core/platform/default:integral_types.h", "//tensorflow/core/platform/default:logging.h", ], copts = tf_copts(), - linkopts = ["-ldl"], + linkopts = if_android(["-ldl"]), deps = [ ":core_stringpiece", "//tensorflow/core/platform:dynamic_annotations", @@ -2063,14 +2075,13 @@ cc_library( ) cc_library( - name = "android_gif_internal", - srcs = if_android([ + name = "portable_gif_internal", + srcs = if_mobile([ "lib/gif/gif_io.cc", "//tensorflow/core/platform:gif_hdrs", ]), hdrs = [ "lib/gif/gif_io.h", - "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/lib/core:legacy_lib_core_stringpiece_header", "//tensorflow/core/lib/gtl:legacy_android_gif_internal_headers", "//tensorflow/core/platform:gif_internal_hdrs", @@ -2078,21 +2089,27 @@ cc_library( "//tensorflow/core/platform/default:logging.h", ], copts = tf_copts(), - linkopts = ["-ldl"], + linkopts = if_android(["-ldl"]), deps = [ - "//tensorflow/core/lib/strings:numbers", - "//tensorflow/core/lib/strings:strcat", "//tensorflow/core/platform:dynamic_annotations", "//tensorflow/core/platform:gif", "//tensorflow/core/platform:logging", - "//tensorflow/core/platform:numbers", - "//tensorflow/core/platform:strcat", "//tensorflow/core/platform:stringpiece", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], ) +alias( + name = "android_jpeg_internal", + actual = ":portable_jpeg_internal", +) + +alias( + name = "android_gif_internal", + actual = ":portable_gif_internal", +) + alias( name = "error_codes_proto_impl", actual = "//tensorflow/core/protobuf:error_codes_proto_impl", @@ -2693,27 +2710,27 @@ tf_cc_test_mkl( "//tensorflow/core/kernels:ops_util", "//third_party/eigen3", ] + if_mkl([ - "//tensorflow/core/kernels:mkl_aggregate_ops", - "//tensorflow/core/kernels:mkl_batch_matmul_op", - "//tensorflow/core/kernels:mkl_concat_op", - "//tensorflow/core/kernels:mkl_conv_op", - "//tensorflow/core/kernels:mkl_cwise_ops_common", - "//tensorflow/core/kernels:mkl_dequantize_op", - "//tensorflow/core/kernels:mkl_fused_batch_norm_op", - "//tensorflow/core/kernels:mkl_identity_op", - "//tensorflow/core/kernels:mkl_input_conversion_op", - "//tensorflow/core/kernels:mkl_lrn_op", - "//tensorflow/core/kernels:mkl_matmul_op", - "//tensorflow/core/kernels:mkl_pooling_ops", - "//tensorflow/core/kernels:mkl_qmatmul_op", - "//tensorflow/core/kernels:mkl_quantize_op", - "//tensorflow/core/kernels:mkl_relu_op", - "//tensorflow/core/kernels:mkl_reshape_op", - "//tensorflow/core/kernels:mkl_slice_op", - 
"//tensorflow/core/kernels:mkl_softmax_op", - "//tensorflow/core/kernels:mkl_tfconv_op", - "//tensorflow/core/kernels:mkl_transpose_op", - "//tensorflow/core/kernels:mkl_tmp_bf16_ops", + "//tensorflow/core/kernels/mkl:mkl_aggregate_ops", + "//tensorflow/core/kernels/mkl:mkl_batch_matmul_op", + "//tensorflow/core/kernels/mkl:mkl_concat_op", + "//tensorflow/core/kernels/mkl:mkl_conv_op", + "//tensorflow/core/kernels/mkl:mkl_cwise_ops_common", + "//tensorflow/core/kernels/mkl:mkl_dequantize_op", + "//tensorflow/core/kernels/mkl:mkl_fused_batch_norm_op", + "//tensorflow/core/kernels/mkl:mkl_identity_op", + "//tensorflow/core/kernels/mkl:mkl_input_conversion_op", + "//tensorflow/core/kernels/mkl:mkl_lrn_op", + "//tensorflow/core/kernels/mkl:mkl_matmul_op", + "//tensorflow/core/kernels/mkl:mkl_pooling_ops", + "//tensorflow/core/kernels/mkl:mkl_qmatmul_op", + "//tensorflow/core/kernels/mkl:mkl_quantize_op", + "//tensorflow/core/kernels/mkl:mkl_relu_op", + "//tensorflow/core/kernels/mkl:mkl_reshape_op", + "//tensorflow/core/kernels/mkl:mkl_slice_op", + "//tensorflow/core/kernels/mkl:mkl_softmax_op", + "//tensorflow/core/kernels/mkl:mkl_tfconv_op", + "//tensorflow/core/kernels/mkl:mkl_transpose_op", + "//tensorflow/core/kernels/mkl:mkl_tmp_bf16_ops", ]), ) @@ -2965,6 +2982,8 @@ filegroup( srcs = [ # PNG data "//tensorflow/core/lib/png:testdata", + "//tensorflow/core/lib/ssim:testdata", + "//tensorflow/core/lib/psnr:testdata", # JPEG data "lib/jpeg/testdata/jpeg_merge_test1.jpg", "lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg", @@ -2986,44 +3005,14 @@ filegroup( # GIF data with optimization "lib/gif/testdata/optimized.gif", # BMP data - "lib/bmp/testdata/lena.bmp", - "lib/bmp/testdata/rgb_small.bmp", - "lib/bmp/testdata/rgb_small_255.bmp", - "lib/bmp/testdata/rgba_small.bmp", - "lib/bmp/testdata/rgba_small_255.bmp", - "lib/bmp/testdata/grayscale_small.bmp", - "lib/bmp/testdata/grayscale_small_3channels.bmp", - "lib/bmp/testdata/grayscale_small_4channels.bmp", - # SSIM, PSNR data - "lib/ssim/testdata/checkerboard1.png", - "lib/ssim/testdata/checkerboard2.png", - "lib/ssim/testdata/checkerboard3.png", - "lib/psnr/testdata/cat_q20.jpg", - "lib/psnr/testdata/cat_q72.jpg", - "lib/psnr/testdata/cat_q95.jpg", + "//tensorflow/core/lib/bmp:bmp_testdata", ], visibility = ["//visibility:public"], ) -filegroup( +alias( name = "lmdb_testdata", - testonly = 1, - srcs = [ - # A simple key-value store: - # 0 : 'b' - # 1 : 'b' - # ... - # 9 : 'b' - # Which is then overwritten with: - # 0 : 'a' - # 1 : 'b' - # ... - # 9 : 'j' - "lib/lmdb/testdata/data.mdb", - # LMDB, being a memory-mapped database, uses a different file format on - # big-endian systems. 
- "lib/lmdb/testdata/data_bigendian.mdb", - ], + actual = "//tensorflow/core/lib/lmdb:lmdb_testdata", visibility = ["//visibility:public"], ) diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index dfa0b78cb17..e72f74e26e4 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -37,9 +37,9 @@ filegroup( visibility = ["//tensorflow:internal"], ) -filegroup( +alias( name = "java_api_def", - srcs = glob(["java_api/*"]), + actual = "//tensorflow/core/api_def/java_api:java_api_def", visibility = ["//tensorflow:internal"], ) diff --git a/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt b/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt index 2184b644b23..dc018aec4aa 100644 --- a/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt @@ -1,4 +1,11 @@ op { graph_op_name: "Acos" summary: "Computes acos of x element-wise." + description: <